diff --git a/app/services/burst_location.py b/app/services/burst_location.py index eeb1f1f..d050ef0 100644 --- a/app/services/burst_location.py +++ b/app/services/burst_location.py @@ -130,11 +130,11 @@ def run_burst_location_by_network( end_dt=normal_end_dt, data_type="pressure", series_name="normal_pressure", - simulation_source="scheme", + simulation_source="realtime", simulation_scheme_name=simulation_scheme_name, simulation_scheme_type=resolved_simulation_scheme_type, ) - observed_source = "simulation_scheme_timerange" + observed_source = "simulation_scheme_burst_realtime_normal_timerange" else: burst_pressure_series, burst_pressure_samples = ( _build_observed_series_from_scada( @@ -208,7 +208,7 @@ def run_burst_location_by_network( end_dt=normal_end_dt, data_type="flow", series_name="normal_flow", - simulation_source="scheme", + simulation_source="realtime", simulation_scheme_name=simulation_scheme_name, simulation_scheme_type=resolved_simulation_scheme_type, ) @@ -270,6 +270,9 @@ def run_burst_location_by_network( "normal": normal_pressure_samples, }, "flow_samples": {"burst": burst_flow_samples, "normal": normal_flow_samples}, + "burst_leakage": burst_leakage, + "min_dpressure": min_dpressure, + "basic_pressure": basic_pressure, } if use_scada_pressure: payload["scada_window"] = { diff --git a/tests/unit/test_burst_location_service.py b/tests/unit/test_burst_location_service.py index 1511634..508db0c 100644 --- a/tests/unit/test_burst_location_service.py +++ b/tests/unit/test_burst_location_service.py @@ -78,6 +78,7 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, module = _load_burst_location_module() captured = {} scheme_calls = [] + realtime_calls = [] monkeypatch.setattr( module, @@ -138,6 +139,16 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, return {"J2": _build_series(kwargs["start_time"], values)} raise AssertionError(f"Unexpected scheme query: {kwargs}") + def fake_realtime_query(**kwargs): + realtime_calls.append(kwargs) + if kwargs["element_type"] == "node" and kwargs["field"] == "pressure": + return {"J1": _build_series(kwargs["start_time"], [8.0, 10.0, 12.0, 14.0])} + if kwargs["element_type"] == "link" and kwargs["field"] == "flow": + return {"P1": _build_series(kwargs["start_time"], [2.0, 4.0, 6.0, 8.0])} + if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand": + return {"J2": _build_series(kwargs["start_time"], [1.0, 3.0, 5.0, 7.0])} + raise AssertionError(f"Unexpected realtime query: {kwargs}") + monkeypatch.setattr( module.InternalQueries, "query_scheme_simulation_by_ids_timerange", @@ -146,7 +157,7 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, monkeypatch.setattr( module.InternalQueries, "query_realtime_simulation_by_ids_timerange", - staticmethod(lambda **kwargs: pytest.fail(f"Unexpected realtime query: {kwargs}")), + staticmethod(fake_realtime_query), ) result = module.run_burst_location_by_network( @@ -162,7 +173,7 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, use_scada_flow=True, ) - assert result["observed_source"] == "simulation_scheme_timerange" + assert result["observed_source"] == "simulation_scheme_burst_realtime_normal_timerange" assert result["simulation_scheme"] == { "name": "BurstSchemeA", "type": "burst_analysis", @@ -177,9 +188,14 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, assert captured["normal_flow"]["J2"] == pytest.approx(4.0) assert captured["normal_flow"]["P1"] == pytest.approx(5.0) assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls) + assert len(scheme_calls) == 3 assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls) assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls) assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls) + assert len(realtime_calls) == 3 + assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls) + assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls) + assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls) def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):