新增tests;新增管道健康预测方法;更新复合查询函数输出格式

This commit is contained in:
JIANG
2025-12-15 18:31:29 +08:00
parent ea33fc270d
commit 9b5707841b
7 changed files with 261 additions and 36 deletions

3
api_ex/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .Fdataclean import *
from .Pdataclean import *
from .pipeline_health_analyzer import *

Binary file not shown.

View File

@@ -0,0 +1,140 @@
import os
import joblib
import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
import matplotlib.pyplot as plt
class PipelineHealthAnalyzer:
"""
管道健康分析器类,使用随机生存森林模型预测管道的生存概率。
该类封装了模型加载和预测功能,便于在其他项目中复用。
模型基于4个特征进行生存分析预测材料、直径、流速、压力。
使用前需确保安装依赖joblib, pandas, numpy, scikit-survival, matplotlib。
"""
def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'):
"""
初始化分析器,加载预训练的随机生存森林模型。
:param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。
:raises FileNotFoundError: 如果模型文件不存在。
:raises Exception: 如果模型加载失败。
"""
# 确保 model 目录存在
model_dir = os.path.dirname(model_path)
if model_dir and not os.path.exists(model_dir):
os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件未找到: {model_path}")
try:
self.rsf = joblib.load(model_path)
self.features = [
'Material', 'Diameter', 'Flow Velocity',
'Pressure', # 'Temperature', 'Precipitation',
# 'Location', 'Structural Defects', 'Functional Defects'
]
except Exception as e:
raise Exception(f"加载模型时出错: {str(e)}")
def predict_survival(self, data: pd.DataFrame) -> list:
"""
基于输入数据预测生存函数。
:param data: pandas DataFrame包含4个必需特征列。数据应为数值型或可转换为数值型。
:return: 生存函数列表每个元素为一个生存函数对象包含时间点x和生存概率y
:raises ValueError: 如果数据缺少必需特征或格式不正确。
"""
# 检查必需特征是否存在
missing_features = [feat for feat in self.features if feat not in data.columns]
if missing_features:
raise ValueError(f"数据缺少必需特征: {missing_features}")
# 提取特征数据
try:
x_test = data[self.features].astype(float) # 确保数值型
except ValueError as e:
raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}")
# 进行预测
survival_functions = self.rsf.predict_survival_function(x_test)
return list(survival_functions)
def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True):
"""
可视化生存函数,生成生存概率图表。
:param survival_functions: predict_survival返回的生存函数列表。
:param save_path: 可选,保存图表的路径(.png格式。如果为None则不保存。
:param show_plot: 是否显示图表(在交互环境中)。
"""
plt.figure(figsize=(10, 6))
for i, sf in enumerate(survival_functions):
plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}')
plt.xlabel("时间(年)")
plt.ylabel("生存概率")
plt.title("管道生存概率预测")
plt.legend()
plt.grid(True, alpha=0.3)
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"图表已保存到: {save_path}")
if show_plot:
plt.show()
else:
plt.close()
# 调用说明示例
"""
在其他项目中使用PipelineHealthAnalyzer类的步骤
1. 安装依赖在requirements.txt中添加
joblib==1.5.0
pandas==2.2.3
numpy==2.0.2
scikit-survival==0.23.1
matplotlib==3.9.4
2. 导入类:
from pipeline_health_analyzer import PipelineHealthAnalyzer
3. 初始化分析器(替换为实际模型路径):
analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib')
4. 准备数据pandas DataFrame包含9个特征列
import pandas as pd
data = pd.DataFrame({
'Material': [1, 2], # 示例数据
'Diameter': [100, 150],
'Flow Velocity': [1.5, 2.0],
'Pressure': [50, 60],
'Temperature': [20, 25],
'Precipitation': [0.1, 0.2],
'Location': [1, 2],
'Structural Defects': [0, 1],
'Functional Defects': [0, 0]
})
5. 进行预测:
survival_funcs = analyzer.predict_survival(data)
6. 查看结果(每个样本的生存概率随时间变化):
for i, sf in enumerate(survival_funcs):
print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...")
7. 可视化(可选):
analyzer.plot_survival(survival_funcs, save_path='survival_plot.png')
注意:
- 数据格式必须匹配特征列表,特征值为数值型。
- 模型文件需从原项目复制或重新训练。
- 如果需要自定义特征或模型参数可修改类中的features列表或继承此类。
"""

0
tests/__init__.py Normal file
View File

14
tests/conftest.py Normal file
View File

@@ -0,0 +1,14 @@
import pytest
import sys
import os
# 自动添加项目根目录到路径(处理项目结构)
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
def run_this_test(test_file):
"""自定义函数运行单个测试文件类似pytest"""
# 提取测试文件名(无扩展名)
test_name = os.path.splitext(os.path.basename(test_file))[0]
# 使用pytest运行自动处理导入
pytest.main([test_file, "-v"])

View File

@@ -0,0 +1,62 @@
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
def test_pipeline_health_analyzer():
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
analyzer = PipelineHealthAnalyzer(
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
)
# 创建示例输入数据9个样本
import pandas as pd
import time
base_data = pd.DataFrame(
{
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
}
)
# 复制生成10万样本
num_samples = 100
repetitions = num_samples // len(base_data) + 1
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
num_samples
)
print(f"生成了 {len(sample_data)} 个样本")
# 记录开始时间
start_time = time.time()
# 进行生存预测
survival_functions = analyzer.predict_survival(sample_data)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
print(f"预测耗时: {elapsed_time:.2f}")
# 打印预测结果示例
print(f"预测了 {len(survival_functions)} 个生存函数")
for i, sf in enumerate(survival_functions):
print(
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
)
# 验证返回结果类型
assert isinstance(survival_functions, list), "返回值应为列表"
assert all(
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
), "每个生存函数应包含x和y属性"
# 可选:测试绘图功能(不显示图表)
analyzer.plot_survival(survival_functions, show_plot=True)
if __name__ == "__main__":
import conftest
conftest.run_this_test(__file__) # 自定义运行,类似示例

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict
from datetime import datetime
from psycopg import AsyncConnection
import pandas as pd
@@ -23,7 +23,7 @@ class CompositeQueries:
device_ids: List[str],
start_time: datetime,
end_time: datetime,
) -> List[Optional[Any]]:
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 SCADA 关联的 link/node 模拟值
@@ -39,12 +39,12 @@ class CompositeQueries:
field: 要查询的字段名
Returns:
模拟数据值列表,如果没有找到则对应位置返回 None
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises:
ValueError: 当 SCADA 设备未找到或字段无效时
"""
results = []
result = {}
# 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
@@ -75,8 +75,11 @@ class CompositeQueries:
)
else:
raise ValueError(f"Unknown SCADA type: {scada_type}")
results.append(res)
return results
# 添加 scada_id 到每个数据项
for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod
async def get_scada_associated_scheme_simulation_data(
@@ -87,7 +90,7 @@ class CompositeQueries:
end_time: datetime,
scheme_type: str,
scheme_name: str,
) -> List[Optional[Any]]:
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 SCADA 关联的 link/node 模拟值
@@ -103,12 +106,12 @@ class CompositeQueries:
field: 要查询的字段名
Returns:
模拟数据值列表,如果没有找到则对应位置返回 None
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises:
ValueError: 当 SCADA 设备未找到或字段无效时
"""
results = []
result = {}
# 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
@@ -151,8 +154,11 @@ class CompositeQueries:
)
else:
raise ValueError(f"Unknown SCADA type: {scada_type}")
results.append(res)
return results
# 添加 scada_id 到每个数据项
for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod
async def get_element_associated_scada_data(
@@ -282,31 +288,31 @@ class CompositeQueries:
]
# 处理pressure数据
# if pressure_ids:
# pressure_df = df[pressure_ids]
# # 重置索引,将 time 变为普通列
# pressure_df = pressure_df.reset_index()
# # 移除 time 列,准备输入给清洗方法
# value_df = pressure_df.drop(columns=["time"])
# # 调用清洗方法
# cleaned_value_df = clean_pressure_data_df_km(value_df)
# # 添加 time 列到首列
# cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
# # 将清洗后的数据写回数据库
# for device_id in pressure_ids:
# if device_id in cleaned_df.columns:
# cleaned_values = cleaned_df[device_id].tolist()
# time_values = cleaned_df["time"].tolist()
# for i, time_str in enumerate(time_values):
# time_dt = datetime.fromisoformat(time_str)
# value = cleaned_values[i]
# await ScadaRepository.update_scada_field(
# timescale_conn,
# time_dt,
# device_id,
# "cleaned_value",
# value,
# )
if pressure_ids:
pressure_df = df[pressure_ids]
# 重置索引,将 time 变为普通列
pressure_df = pressure_df.reset_index()
# 移除 time 列,准备输入给清洗方法
value_df = pressure_df.drop(columns=["time"])
# 调用清洗方法
cleaned_value_df = clean_pressure_data_df_km(value_df)
# 添加 time 列到首列
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
# 将清洗后的数据写回数据库
for device_id in pressure_ids:
if device_id in cleaned_df.columns:
cleaned_values = cleaned_df[device_id].tolist()
time_values = cleaned_df["time"].tolist()
for i, time_str in enumerate(time_values):
time_dt = datetime.fromisoformat(time_str)
value = cleaned_values[i]
await ScadaRepository.update_scada_field(
timescale_conn,
time_dt,
device_id,
"cleaned_value",
value,
)
# 处理flow数据
if flow_ids: