新增tests;新增管道健康预测方法;更新复合查询函数输出格式

This commit is contained in:
JIANG
2025-12-15 18:31:29 +08:00
parent ea33fc270d
commit 9b5707841b
7 changed files with 261 additions and 36 deletions

3
api_ex/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .Fdataclean import *
from .Pdataclean import *
from .pipeline_health_analyzer import *

Binary file not shown.

View File

@@ -0,0 +1,140 @@
import os
import joblib
import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
import matplotlib.pyplot as plt
class PipelineHealthAnalyzer:
"""
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
该类封装了模型加载和预测功能,便于在其他项目中复用。
模型基于4个特征进行生存分析预测材料、直径、流速、压力。
使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。
"""
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
"""
初始化分析器,加载预训练的随机生存森林模型。
:param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。
:raises FileNotFoundError: 如果模型文件不存在。
:raises Exception: 如果模型加载失败。
"""
# 确保 model 目录存在
model_dir = os.path.dirname(model_path)
if model_dir and not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件未找到: {model_path}")
try:
self.rsf = joblib.load(model_path)
self.features = [
'Material', 'Diameter', 'Flow Velocity',
'Pressure', # 'Temperature', 'Precipitation',
# 'Location', 'Structural Defects', 'Functional Defects'
]
except Exception as e:
raise Exception(f"加载模型时出错: {str(e)}")
def predict_survival(self, data: pd.DataFrame) -> list:
"""
基于输入数据预测生存函数。
:param data: pandas DataFrame包含4个必需特征列。数据应为数值型或可转换为数值型。
:return: 生存函数列表每个元素为一个生存函数对象包含时间点x和生存概率y
:raises ValueError: 如果数据缺少必需特征或格式不正确。
"""
# 检查必需特征是否存在
missing_features = [feat for feat in self.features if feat not in data.columns]
if missing_features:
raise ValueError(f"数据缺少必需特征: {missing_features}")
# 提取特征数据
try:
x_test = data[self.features].astype(float) # 确保数值型
except ValueError as e:
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
# 进行预测
survival_functions = self.rsf.predict_survival_function(x_test)
return list(survival_functions)
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
"""
可视化生存函数,生成生存概率图表。
:param survival_functions: predict_survival返回的生存函数列表。
:param save_path: 可选,保存图表的路径(.png格式。如果为None则不保存。
:param show_plot: 是否显示图表(在交互环境中)。
"""
plt.figure(figsize=(10, 6))
for i, sf in enumerate(survival_functions):
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
plt.xlabel("时间(年)")
plt.ylabel("生存概率")
plt.title("管道生存概率预测")
plt.legend()
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"图表已保存到: {save_path}")
if show_plot:
plt.show()
else:
plt.close()
# 调用说明示例
"""
在其他项目中使用PipelineHealthAnalyzer类的步骤
1. 安装依赖在requirements.txt中添加
joblib==1.5.0
pandas==2.2.3
numpy==2.0.2
scikit-survival==0.23.1
matplotlib==3.9.4
2. 导入类:
from pipeline_health_analyzer import PipelineHealthAnalyzer
3. 初始化分析器(替换为实际模型路径):
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
4. 准备数据pandas DataFrame包含9个特征列
import pandas as pd
data = pd.DataFrame({
'Material': [1, 2], # 示例数据
'Diameter': [100, 150],
'Flow Velocity': [1.5, 2.0],
'Pressure': [50, 60],
'Temperature': [20, 25],
'Precipitation': [0.1, 0.2],
'Location': [1, 2],
'Structural Defects': [0, 1],
'Functional Defects': [0, 0]
})
5. 进行预测:
survival_funcs = analyzer.predict_survival(data)
6. 查看结果(每个样本的生存概率随时间变化):
for i, sf in enumerate(survival_funcs):
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
7. 可视化(可选):
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
注意:
- 数据格式必须匹配特征列表,特征值为数值型。
- 模型文件需从原项目复制或重新训练。
- 如果需要自定义特征或模型参数可修改类中的features列表或继承此类。
"""

0
tests/__init__.py Normal file
View File

14
tests/conftest.py Normal file
View File

@@ -0,0 +1,14 @@
import pytest
import sys
import os
# 自动添加项目根目录到路径(处理项目结构)
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
def run_this_test(test_file):
"""自定义函数运行单个测试文件类似pytest"""
# 提取测试文件名(无扩展名)
test_name = os.path.splitext(os.path.basename(test_file))[0]
# 使用pytest运行自动处理导入
pytest.main([test_file, "-v"])

View File

@@ -0,0 +1,62 @@
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
def test_pipeline_health_analyzer():
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
analyzer = PipelineHealthAnalyzer(
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
)
# 创建示例输入数据9个样本
import pandas as pd
import time
base_data = pd.DataFrame(
{
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
}
)
# 复制生成10万样本
num_samples = 100
repetitions = num_samples // len(base_data) + 1
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
num_samples
)
print(f"生成了 {len(sample_data)} 个样本")
# 记录开始时间
start_time = time.time()
# 进行生存预测
survival_functions = analyzer.predict_survival(sample_data)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
print(f"预测耗时: {elapsed_time:.2f}")
# 打印预测结果示例
print(f"预测了 {len(survival_functions)} 个生存函数")
for i, sf in enumerate(survival_functions):
print(
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
)
# 验证返回结果类型
assert isinstance(survival_functions, list), "返回值应为列表"
assert all(
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
), "每个生存函数应包含x和y属性"
# 可选:测试绘图功能(不显示图表)
analyzer.plot_survival(survival_functions, show_plot=True)
if __name__ == "__main__":
import conftest
conftest.run_this_test(__file__) # 自定义运行,类似示例

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Any from typing import List, Optional, Any, Dict
from datetime import datetime from datetime import datetime
from psycopg import AsyncConnection from psycopg import AsyncConnection
import pandas as pd import pandas as pd
@@ -23,7 +23,7 @@ class CompositeQueries:
device_ids: List[str], device_ids: List[str],
start_time: datetime, start_time: datetime,
end_time: datetime, end_time: datetime,
) -> List[Optional[Any]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
获取 SCADA 关联的 link/node 模拟值 获取 SCADA 关联的 link/node 模拟值
@@ -39,12 +39,12 @@ class CompositeQueries:
field: 要查询的字段名 field: 要查询的字段名
Returns: Returns:
模拟数据值列表,如果没有找到则对应位置返回 None 模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises: Raises:
ValueError: 当 SCADA 设备未找到或字段无效时 ValueError: 当 SCADA 设备未找到或字段无效时
""" """
results = [] result = {}
# 1. 查询所有 SCADA 信息 # 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn) scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
@@ -75,8 +75,11 @@ class CompositeQueries:
) )
else: else:
raise ValueError(f"Unknown SCADA type: {scada_type}") raise ValueError(f"Unknown SCADA type: {scada_type}")
results.append(res) # 添加 scada_id 到每个数据项
return results for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod @staticmethod
async def get_scada_associated_scheme_simulation_data( async def get_scada_associated_scheme_simulation_data(
@@ -87,7 +90,7 @@ class CompositeQueries:
end_time: datetime, end_time: datetime,
scheme_type: str, scheme_type: str,
scheme_name: str, scheme_name: str,
) -> List[Optional[Any]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
获取 SCADA 关联的 link/node 模拟值 获取 SCADA 关联的 link/node 模拟值
@@ -103,12 +106,12 @@ class CompositeQueries:
field: 要查询的字段名 field: 要查询的字段名
Returns: Returns:
模拟数据值列表,如果没有找到则对应位置返回 None 模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises: Raises:
ValueError: 当 SCADA 设备未找到或字段无效时 ValueError: 当 SCADA 设备未找到或字段无效时
""" """
results = [] result = {}
# 1. 查询所有 SCADA 信息 # 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn) scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
@@ -151,8 +154,11 @@ class CompositeQueries:
) )
else: else:
raise ValueError(f"Unknown SCADA type: {scada_type}") raise ValueError(f"Unknown SCADA type: {scada_type}")
results.append(res) # 添加 scada_id 到每个数据项
return results for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod @staticmethod
async def get_element_associated_scada_data( async def get_element_associated_scada_data(
@@ -282,31 +288,31 @@ class CompositeQueries:
] ]
# 处理pressure数据 # 处理pressure数据
# if pressure_ids: if pressure_ids:
# pressure_df = df[pressure_ids] pressure_df = df[pressure_ids]
# # 重置索引,将 time 变为普通列 # 重置索引,将 time 变为普通列
# pressure_df = pressure_df.reset_index() pressure_df = pressure_df.reset_index()
# # 移除 time 列,准备输入给清洗方法 # 移除 time 列,准备输入给清洗方法
# value_df = pressure_df.drop(columns=["time"]) value_df = pressure_df.drop(columns=["time"])
# # 调用清洗方法 # 调用清洗方法
# cleaned_value_df = clean_pressure_data_df_km(value_df) cleaned_value_df = clean_pressure_data_df_km(value_df)
# # 添加 time 列到首列 # 添加 time 列到首列
# cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1) cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
# # 将清洗后的数据写回数据库 # 将清洗后的数据写回数据库
# for device_id in pressure_ids: for device_id in pressure_ids:
# if device_id in cleaned_df.columns: if device_id in cleaned_df.columns:
# cleaned_values = cleaned_df[device_id].tolist() cleaned_values = cleaned_df[device_id].tolist()
# time_values = cleaned_df["time"].tolist() time_values = cleaned_df["time"].tolist()
# for i, time_str in enumerate(time_values): for i, time_str in enumerate(time_values):
# time_dt = datetime.fromisoformat(time_str) time_dt = datetime.fromisoformat(time_str)
# value = cleaned_values[i] value = cleaned_values[i]
# await ScadaRepository.update_scada_field( await ScadaRepository.update_scada_field(
# timescale_conn, timescale_conn,
# time_dt, time_dt,
# device_id, device_id,
# "cleaned_value", "cleaned_value",
# value, value,
# ) )
# 处理flow数据 # 处理flow数据
if flow_ids: if flow_ids: