Files
TJWaterServerBinary/tests/test_pipeline_health_analyzer.py

63 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
def test_pipeline_health_analyzer():
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
analyzer = PipelineHealthAnalyzer(
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
)
# 创建示例输入数据9个样本
import pandas as pd
import time
base_data = pd.DataFrame(
{
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
}
)
# 复制生成10万样本
num_samples = 100
repetitions = num_samples // len(base_data) + 1
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
num_samples
)
print(f"生成了 {len(sample_data)} 个样本")
# 记录开始时间
start_time = time.time()
# 进行生存预测
survival_functions = analyzer.predict_survival(sample_data)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
print(f"预测耗时: {elapsed_time:.2f}")
# 打印预测结果示例
print(f"预测了 {len(survival_functions)} 个生存函数")
for i, sf in enumerate(survival_functions):
print(
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
)
# 验证返回结果类型
assert isinstance(survival_functions, list), "返回值应为列表"
assert all(
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
), "每个生存函数应包含x和y属性"
# 可选:测试绘图功能(不显示图表)
analyzer.plot_survival(survival_functions, show_plot=True)
if __name__ == "__main__":
import conftest
conftest.run_this_test(__file__) # 自定义运行,类似示例