diff --git a/api_ex/__init__.py b/api_ex/__init__.py new file mode 100644 index 0000000..dcd7666 --- /dev/null +++ b/api_ex/__init__.py @@ -0,0 +1,3 @@ +from .Fdataclean import * +from .Pdataclean import * +from .pipeline_health_analyzer import * \ No newline at end of file diff --git a/api_ex/model/my_survival_forest_model_quxi.zip b/api_ex/model/my_survival_forest_model_quxi.zip new file mode 100644 index 0000000..6ad9f68 Binary files /dev/null and b/api_ex/model/my_survival_forest_model_quxi.zip differ diff --git a/api_ex/pipeline_health_analyzer.py b/api_ex/pipeline_health_analyzer.py new file mode 100644 index 0000000..7cdb699 --- /dev/null +++ b/api_ex/pipeline_health_analyzer.py @@ -0,0 +1,140 @@ +import os +import joblib +import pandas as pd +import numpy as np +from sksurv.ensemble import RandomSurvivalForest +import matplotlib.pyplot as plt + + +class PipelineHealthAnalyzer: + """ + 管道健康分析器类,使用随机生存森林模型预测管道的生存概率。 + + 该类封装了模型加载和预测功能,便于在其他项目中复用。 + 模型基于4个特征进行生存分析预测:材料、直径、流速、压力。 + + 使用前需确保安装依赖:joblib, pandas, numpy, scikit-survival, matplotlib。 + """ + + def __init__(self, model_path: str = 'model/my_survival_forest_model_quxi.joblib'): + """ + 初始化分析器,加载预训练的随机生存森林模型。 + + :param model_path: 模型文件的路径(默认为相对路径 'model/my_survival_forest_model_quxi.joblib')。 + :raises FileNotFoundError: 如果模型文件不存在。 + :raises Exception: 如果模型加载失败。 + """ + # 确保 model 目录存在 + model_dir = os.path.dirname(model_path) + if model_dir and not os.path.exists(model_dir): + os.makedirs(model_dir, exist_ok=True) + + if not os.path.exists(model_path): + raise FileNotFoundError(f"模型文件未找到: {model_path}") + + try: + self.rsf = joblib.load(model_path) + self.features = [ + 'Material', 'Diameter', 'Flow Velocity', + 'Pressure', # 'Temperature', 'Precipitation', + # 'Location', 'Structural Defects', 'Functional Defects' + ] + except Exception as e: + raise Exception(f"加载模型时出错: {str(e)}") + + def predict_survival(self, data: pd.DataFrame) -> list: + """ + 基于输入数据预测生存函数。 + + :param data: pandas DataFrame,包含4个必需特征列。数据应为数值型或可转换为数值型。 + :return: 生存函数列表,每个元素为一个生存函数对象(包含时间点x和生存概率y)。 + :raises ValueError: 如果数据缺少必需特征或格式不正确。 + """ + # 检查必需特征是否存在 + missing_features = [feat for feat in self.features if feat not in data.columns] + if missing_features: + raise ValueError(f"数据缺少必需特征: {missing_features}") + + # 提取特征数据 + try: + x_test = data[self.features].astype(float) # 确保数值型 + except ValueError as e: + raise ValueError(f"特征数据转换失败,请检查数据类型: {str(e)}") + + # 进行预测 + survival_functions = self.rsf.predict_survival_function(x_test) + return list(survival_functions) + + def plot_survival(self, survival_functions: list, save_path: str = None, show_plot: bool = True): + """ + 可视化生存函数,生成生存概率图表。 + + :param survival_functions: predict_survival返回的生存函数列表。 + :param save_path: 可选,保存图表的路径(.png格式)。如果为None,则不保存。 + :param show_plot: 是否显示图表(在交互环境中)。 + """ + plt.figure(figsize=(10, 6)) + for i, sf in enumerate(survival_functions): + plt.step(sf.x, sf.y, where='post', label=f'样本 {i + 1}') + plt.xlabel("时间(年)") + plt.ylabel("生存概率") + plt.title("管道生存概率预测") + plt.legend() + plt.grid(True, alpha=0.3) + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + print(f"图表已保存到: {save_path}") + + if show_plot: + plt.show() + else: + plt.close() + + +# 调用说明示例 +""" +在其他项目中使用PipelineHealthAnalyzer类的步骤: + +1. 安装依赖(在requirements.txt中添加): + joblib==1.5.0 + pandas==2.2.3 + numpy==2.0.2 + scikit-survival==0.23.1 + matplotlib==3.9.4 + +2. 导入类: + from pipeline_health_analyzer import PipelineHealthAnalyzer + +3. 初始化分析器(替换为实际模型路径): + analyzer = PipelineHealthAnalyzer(model_path='path/to/my_survival_forest_model3-10.joblib') + +4. 准备数据(pandas DataFrame,包含9个特征列): + import pandas as pd + data = pd.DataFrame({ + 'Material': [1, 2], # 示例数据 + 'Diameter': [100, 150], + 'Flow Velocity': [1.5, 2.0], + 'Pressure': [50, 60], + 'Temperature': [20, 25], + 'Precipitation': [0.1, 0.2], + 'Location': [1, 2], + 'Structural Defects': [0, 1], + 'Functional Defects': [0, 0] + }) + +5. 进行预测: + survival_funcs = analyzer.predict_survival(data) + +6. 查看结果(每个样本的生存概率随时间变化): + for i, sf in enumerate(survival_funcs): + print(f"样本 {i+1}: 时间点: {sf.x[:5]}..., 生存概率: {sf.y[:5]}...") + +7. 可视化(可选): + analyzer.plot_survival(survival_funcs, save_path='survival_plot.png') + +注意: +- 数据格式必须匹配特征列表,特征值为数值型。 +- 模型文件需从原项目复制或重新训练。 +- 如果需要自定义特征或模型参数,可修改类中的features列表或继承此类。 +""" \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b1c7e7b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,14 @@ +import pytest +import sys +import os + +# 自动添加项目根目录到路径(处理项目结构) +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + + +def run_this_test(test_file): + """自定义函数:运行单个测试文件(类似pytest)""" + # 提取测试文件名(无扩展名) + test_name = os.path.splitext(os.path.basename(test_file))[0] + # 使用pytest运行(自动处理导入) + pytest.main([test_file, "-v"]) diff --git a/tests/test_pipeline_health_analyzer.py b/tests/test_pipeline_health_analyzer.py new file mode 100644 index 0000000..489ff31 --- /dev/null +++ b/tests/test_pipeline_health_analyzer.py @@ -0,0 +1,62 @@ +from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer + + +def test_pipeline_health_analyzer(): + # 初始化分析器,假设模型文件路径为'models/rsf_model.joblib' + analyzer = PipelineHealthAnalyzer( + model_path="api_ex/model/my_survival_forest_model_quxi.joblib" + ) + # 创建示例输入数据(9个样本) + import pandas as pd + import time + + base_data = pd.DataFrame( + { + "Material": [7, 11, 7, 7, 7, 7, 7, 7, 7], + "Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2], + "Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], + "Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], + } + ) + + # 复制生成10万样本 + num_samples = 100 + repetitions = num_samples // len(base_data) + 1 + sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head( + num_samples + ) + + print(f"生成了 {len(sample_data)} 个样本") + + # 记录开始时间 + start_time = time.time() + + # 进行生存预测 + survival_functions = analyzer.predict_survival(sample_data) + + # 记录结束时间 + end_time = time.time() + elapsed_time = end_time - start_time + print(f"预测耗时: {elapsed_time:.2f} 秒") + + # 打印预测结果示例 + print(f"预测了 {len(survival_functions)} 个生存函数") + for i, sf in enumerate(survival_functions): + print( + f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}" + ) + + # 验证返回结果类型 + assert isinstance(survival_functions, list), "返回值应为列表" + assert all( + hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions + ), "每个生存函数应包含x和y属性" + + # 可选:测试绘图功能(不显示图表) + analyzer.plot_survival(survival_functions, show_plot=True) + + +if __name__ == "__main__": + import conftest + + conftest.run_this_test(__file__) # 自定义运行,类似示例 diff --git a/timescaledb/composite_queries.py b/timescaledb/composite_queries.py index 61de4c4..813b556 100644 --- a/timescaledb/composite_queries.py +++ b/timescaledb/composite_queries.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any +from typing import List, Optional, Any, Dict from datetime import datetime from psycopg import AsyncConnection import pandas as pd @@ -23,7 +23,7 @@ class CompositeQueries: device_ids: List[str], start_time: datetime, end_time: datetime, - ) -> List[Optional[Any]]: + ) -> Dict[str, List[Dict[str, Any]]]: """ 获取 SCADA 关联的 link/node 模拟值 @@ -39,12 +39,12 @@ class CompositeQueries: field: 要查询的字段名 Returns: - 模拟数据值列表,如果没有找到则对应位置返回 None + 模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id Raises: ValueError: 当 SCADA 设备未找到或字段无效时 """ - results = [] + result = {} # 1. 查询所有 SCADA 信息 scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn) @@ -75,8 +75,11 @@ class CompositeQueries: ) else: raise ValueError(f"Unknown SCADA type: {scada_type}") - results.append(res) - return results + # 添加 scada_id 到每个数据项 + for item in res: + item["scada_id"] = device_id + result[device_id] = res + return result @staticmethod async def get_scada_associated_scheme_simulation_data( @@ -87,7 +90,7 @@ class CompositeQueries: end_time: datetime, scheme_type: str, scheme_name: str, - ) -> List[Optional[Any]]: + ) -> Dict[str, List[Dict[str, Any]]]: """ 获取 SCADA 关联的 link/node 模拟值 @@ -103,12 +106,12 @@ class CompositeQueries: field: 要查询的字段名 Returns: - 模拟数据值列表,如果没有找到则对应位置返回 None + 模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id Raises: ValueError: 当 SCADA 设备未找到或字段无效时 """ - results = [] + result = {} # 1. 查询所有 SCADA 信息 scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn) @@ -151,8 +154,11 @@ class CompositeQueries: ) else: raise ValueError(f"Unknown SCADA type: {scada_type}") - results.append(res) - return results + # 添加 scada_id 到每个数据项 + for item in res: + item["scada_id"] = device_id + result[device_id] = res + return result @staticmethod async def get_element_associated_scada_data( @@ -282,31 +288,31 @@ class CompositeQueries: ] # 处理pressure数据 - # if pressure_ids: - # pressure_df = df[pressure_ids] - # # 重置索引,将 time 变为普通列 - # pressure_df = pressure_df.reset_index() - # # 移除 time 列,准备输入给清洗方法 - # value_df = pressure_df.drop(columns=["time"]) - # # 调用清洗方法 - # cleaned_value_df = clean_pressure_data_df_km(value_df) - # # 添加 time 列到首列 - # cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1) - # # 将清洗后的数据写回数据库 - # for device_id in pressure_ids: - # if device_id in cleaned_df.columns: - # cleaned_values = cleaned_df[device_id].tolist() - # time_values = cleaned_df["time"].tolist() - # for i, time_str in enumerate(time_values): - # time_dt = datetime.fromisoformat(time_str) - # value = cleaned_values[i] - # await ScadaRepository.update_scada_field( - # timescale_conn, - # time_dt, - # device_id, - # "cleaned_value", - # value, - # ) + if pressure_ids: + pressure_df = df[pressure_ids] + # 重置索引,将 time 变为普通列 + pressure_df = pressure_df.reset_index() + # 移除 time 列,准备输入给清洗方法 + value_df = pressure_df.drop(columns=["time"]) + # 调用清洗方法 + cleaned_value_df = clean_pressure_data_df_km(value_df) + # 添加 time 列到首列 + cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1) + # 将清洗后的数据写回数据库 + for device_id in pressure_ids: + if device_id in cleaned_df.columns: + cleaned_values = cleaned_df[device_id].tolist() + time_values = cleaned_df["time"].tolist() + for i, time_str in enumerate(time_values): + time_dt = datetime.fromisoformat(time_str) + value = cleaned_values[i] + await ScadaRepository.update_scada_field( + timescale_conn, + time_dt, + device_id, + "cleaned_value", + value, + ) # 处理flow数据 if flow_ids: