新增tests;新增管道健康预测方法;更新复合查询函数输出格式
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
14
tests/conftest.py
Normal file
14
tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 自动添加项目根目录到路径(处理项目结构)
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
|
||||
def run_this_test(test_file):
|
||||
"""自定义函数:运行单个测试文件(类似pytest)"""
|
||||
# 提取测试文件名(无扩展名)
|
||||
test_name = os.path.splitext(os.path.basename(test_file))[0]
|
||||
# 使用pytest运行(自动处理导入)
|
||||
pytest.main([test_file, "-v"])
|
||||
62
tests/test_pipeline_health_analyzer.py
Normal file
62
tests/test_pipeline_health_analyzer.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||||
|
||||
|
||||
def test_pipeline_health_analyzer():
|
||||
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
|
||||
analyzer = PipelineHealthAnalyzer(
|
||||
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
|
||||
)
|
||||
# 创建示例输入数据(9个样本)
|
||||
import pandas as pd
|
||||
import time
|
||||
|
||||
base_data = pd.DataFrame(
|
||||
{
|
||||
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
|
||||
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
|
||||
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
|
||||
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
|
||||
}
|
||||
)
|
||||
|
||||
# 复制生成10万样本
|
||||
num_samples = 100
|
||||
repetitions = num_samples // len(base_data) + 1
|
||||
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
|
||||
num_samples
|
||||
)
|
||||
|
||||
print(f"生成了 {len(sample_data)} 个样本")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 进行生存预测
|
||||
survival_functions = analyzer.predict_survival(sample_data)
|
||||
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"预测耗时: {elapsed_time:.2f} 秒")
|
||||
|
||||
# 打印预测结果示例
|
||||
print(f"预测了 {len(survival_functions)} 个生存函数")
|
||||
for i, sf in enumerate(survival_functions):
|
||||
print(
|
||||
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
|
||||
)
|
||||
|
||||
# 验证返回结果类型
|
||||
assert isinstance(survival_functions, list), "返回值应为列表"
|
||||
assert all(
|
||||
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
|
||||
), "每个生存函数应包含x和y属性"
|
||||
|
||||
# 可选:测试绘图功能(不显示图表)
|
||||
analyzer.plot_survival(survival_functions, show_plot=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import conftest
|
||||
|
||||
conftest.run_this_test(__file__) # 自定义运行,类似示例
|
||||
Reference in New Issue
Block a user