65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
"""
|
||
tests.unit.test_pipeline_health_analyzer 的 Docstring
|
||
"""
|
||
|
||
|
||
def test_pipeline_health_analyzer():
|
||
from app.algorithms.api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
|
||
|
||
# 初始化分析器,假设模型文件路径为'models/rsf_model.joblib'
|
||
analyzer = PipelineHealthAnalyzer()
|
||
# 创建示例输入数据(9个样本)
|
||
import pandas as pd
|
||
import time
|
||
|
||
base_data = pd.DataFrame(
|
||
{
|
||
"Material": [7, 11, 7, 7, 7, 7, 7, 7, 7],
|
||
"Diameter": [6, 6, 3, 2, 3, 3, 2, 2, 2],
|
||
"Flow Velocity": [0.55, 0.32, 0.25, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
|
||
"Pressure": [0.15, 0.15, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],
|
||
}
|
||
)
|
||
|
||
# 复制生成10万样本
|
||
num_samples = 100
|
||
repetitions = num_samples // len(base_data) + 1
|
||
sample_data = pd.concat([base_data] * repetitions, ignore_index=True).head(
|
||
num_samples
|
||
)
|
||
|
||
print(f"生成了 {len(sample_data)} 个样本")
|
||
|
||
# 记录开始时间
|
||
start_time = time.time()
|
||
|
||
# 进行生存预测
|
||
survival_functions = analyzer.predict_survival(sample_data)
|
||
|
||
# 记录结束时间
|
||
end_time = time.time()
|
||
elapsed_time = end_time - start_time
|
||
print(f"预测耗时: {elapsed_time:.2f} 秒")
|
||
|
||
# 打印预测结果示例
|
||
print(f"预测了 {len(survival_functions)} 个生存函数")
|
||
for i, sf in enumerate(survival_functions):
|
||
print(
|
||
f"样本 {i+1}: 时间点数量={len(sf.x)}, 生存概率范围={sf.y.min():.3f} - {sf.y.max():.3f} {sf}"
|
||
)
|
||
|
||
# 验证返回结果类型
|
||
assert isinstance(survival_functions, list), "返回值应为列表"
|
||
assert all(
|
||
hasattr(sf, "x") and hasattr(sf, "y") for sf in survival_functions
|
||
), "每个生存函数应包含x和y属性"
|
||
|
||
# 可选:测试绘图功能(不显示图表)
|
||
analyzer.plot_survival(survival_functions, show_plot=False)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import conftest
|
||
|
||
conftest.run_this_test(__file__) # 自定义运行,类似示例
|