"""Tests for ``app/services/burst_location.py``.

The service module is executed straight from its source file with every
``app.*`` module it imports replaced by a minimal stub, so these tests do not
need the real burst-location algorithm, database queries or network helpers.
"""

import importlib.util
import sys
import types
from datetime import datetime, timedelta
from pathlib import Path

import pytest


def _load_burst_location_module(monkeypatch=None):
    """Load ``app/services/burst_location.py`` against stub dependencies.

    Stub modules are registered in ``sys.modules`` for the ``app.*`` imports
    the service needs, then the real source file is executed under the
    private module name ``tests_burst_location_under_test``.

    Args:
        monkeypatch: Optional pytest ``monkeypatch`` fixture. When given, every
            ``sys.modules`` entry added or replaced here is undone at test
            teardown, so the stubs cannot leak into other tests that import
            the real ``app`` modules. When omitted, entries are written
            directly and persist for the rest of the process.

    Returns:
        The freshly executed service module.
    """
    module_path = (
        Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py"
    )

    def install(name: str, stub: types.ModuleType) -> None:
        # Route every sys.modules write through monkeypatch when available so
        # it is reverted at teardown instead of leaking into later tests.
        if monkeypatch is not None:
            monkeypatch.setitem(sys.modules, name, stub)
        else:
            sys.modules[name] = stub

    def ensure_package(name: str) -> types.ModuleType:
        # Keep a package that is already imported; otherwise register an empty
        # package (``__path__ = []``) so the dotted stub names can live under it.
        module = sys.modules.get(name)
        if module is None:
            module = types.ModuleType(name)
            module.__path__ = []
            install(name, module)
        return module

    for package_name in [
        "app",
        "app.algorithms",
        "app.infra",
        "app.infra.db",
        "app.infra.db.timescaledb",
        "app.services",
    ]:
        ensure_package(package_name)

    algorithms_module = types.ModuleType("app.algorithms.burst_location")
    algorithms_module.run_burst_location = lambda **kwargs: {}
    install("app.algorithms.burst_location", algorithms_module)

    internal_queries_module = types.ModuleType(
        "app.infra.db.timescaledb.internal_queries"
    )

    class DummyInternalQueries:
        # Every query returns an empty mapping; individual tests monkeypatch
        # the ones they exercise.
        @staticmethod
        def query_scada_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_scheme_simulation_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_realtime_simulation_by_ids_timerange(**kwargs):
            return {}

    internal_queries_module.InternalQueries = DummyInternalQueries
    install("app.infra.db.timescaledb.internal_queries", internal_queries_module)

    scheme_management_module = types.ModuleType("app.services.scheme_management")
    scheme_management_module.query_burst_location_scheme_detail = (
        lambda *args, **kwargs: {}
    )
    scheme_management_module.query_burst_location_schemes = (
        lambda *args, **kwargs: []
    )
    scheme_management_module.scheme_name_exists = lambda *args, **kwargs: False
    scheme_management_module.store_scheme_info = lambda *args, **kwargs: None
    install("app.services.scheme_management", scheme_management_module)

    tjnetwork_module = types.ModuleType("app.services.tjnetwork")
    tjnetwork_module.dump_inp = lambda *args, **kwargs: None
    tjnetwork_module.get_all_scada_info = lambda *args, **kwargs: []
    install("app.services.tjnetwork", tjnetwork_module)

    module_name = "tests_burst_location_under_test"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate the spec *before* using it: module_from_spec(None) would fail
    # with an opaque AttributeError instead of this assertion.
    assert spec is not None and spec.loader is not None, (
        f"cannot build an import spec for {module_path}"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _build_series(start_time: str, values: list[float]) -> list[dict]:
    """Return ``{"time", "value"}`` points spaced 15 minutes apart.

    The first point is stamped at *start_time* (an ISO-8601 string) and each
    following point is 15 minutes later; times are emitted as ISO strings.
    """
    base_time = datetime.fromisoformat(start_time)
    return [
        {
            "time": (base_time + timedelta(minutes=15 * index)).isoformat(),
            "value": value,
        }
        for index, value in enumerate(values)
    ]


def _patch_network_setup(monkeypatch, module, tmp_path, scada_info):
    """Stub the SCADA inventory and INP preparation of *module* for one test.

    ``get_all_scada_info`` returns a fresh copy of *scada_info* on every call
    so the service cannot mutate the test's fixture data between calls.
    """
    monkeypatch.setattr(
        module,
        "get_all_scada_info",
        lambda network: [dict(item) for item in scada_info],
    )
    monkeypatch.setattr(
        module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp")
    )


def test_run_burst_location_uses_single_timerange_with_burst_source_split(
    monkeypatch, tmp_path
):
    module = _load_burst_location_module(monkeypatch)
    captured = {}
    scheme_calls = []
    realtime_calls = []

    _patch_network_setup(
        monkeypatch,
        module,
        tmp_path,
        [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            },
            {
                "type": "pipe_flow",
                "associated_element_id": "P1",
                "api_query_id": "pipe-flow-query",
            },
            {
                "type": "demand",
                "associated_element_id": "J2",
                "api_query_id": "demand-query",
            },
        ],
    )

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {
            "located_pipe": "Pipe-001",
            "simulation_times": 3,
            "similarity_mode": "combined",
        }

    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)

    # (element_type, field) -> (element id, values for an 08:00 start,
    # values for any other start hour). Returning different data for a
    # different window makes a wrong time range visible in the averages.
    scheme_series = {
        ("node", "pressure"): ("J1", [12.0, 14.0, 16.0, 18.0], [8.0, 10.0, 12.0, 14.0]),
        ("link", "flow"): ("P1", [5.0, 7.0, 9.0, 11.0], [2.0, 4.0, 6.0, 8.0]),
        ("node", "actual_demand"): ("J2", [3.0, 5.0, 7.0, 9.0], [1.0, 3.0, 5.0, 7.0]),
    }
    # (element_type, field) -> (element id, realtime "normal" values).
    realtime_series = {
        ("node", "pressure"): ("J1", [8.0, 10.0, 12.0, 14.0]),
        ("link", "flow"): ("P1", [2.0, 4.0, 6.0, 8.0]),
        ("node", "actual_demand"): ("J2", [1.0, 3.0, 5.0, 7.0]),
    }

    def fake_scheme_query(**kwargs):
        scheme_calls.append(kwargs)
        key = (kwargs["element_type"], kwargs["field"])
        if key not in scheme_series:
            raise AssertionError(f"Unexpected scheme query: {kwargs}")
        element_id, burst_values, other_values = scheme_series[key]
        start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
        values = burst_values if start_hour == 8 else other_values
        return {element_id: _build_series(kwargs["start_time"], values)}

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        key = (kwargs["element_type"], kwargs["field"])
        if key not in realtime_series:
            raise AssertionError(f"Unexpected realtime query: {kwargs}")
        element_id, values = realtime_series[key]
        return {element_id: _build_series(kwargs["start_time"], values)}

    monkeypatch.setattr(
        module.InternalQueries,
        "query_scheme_simulation_by_ids_timerange",
        staticmethod(fake_scheme_query),
    )
    monkeypatch.setattr(
        module.InternalQueries,
        "query_realtime_simulation_by_ids_timerange",
        staticmethod(fake_realtime_query),
    )

    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="simulation",
        simulation_scheme_name="BurstSchemeA",
        simulation_scheme_type="burst_analysis",
        burst_leakage=10.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
        use_scada_flow=True,
    )

    assert result["observed_source"] == "simulation_scheme_burst_realtime_normal_timerange"
    assert result["simulation_scheme"] == {
        "name": "BurstSchemeA",
        "type": "burst_analysis",
    }
    assert result["pressure_samples"] == {"burst": 4, "normal": 4}
    assert result["flow_samples"] == {"burst": 4, "normal": 4}

    assert list(captured["burst_pressure"].index) == ["J1"]
    assert captured["burst_pressure"]["J1"] == pytest.approx(15.0)
    assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)
    assert captured["burst_flow"]["J2"] == pytest.approx(6.0)
    assert captured["burst_flow"]["P1"] == pytest.approx(8.0)
    assert captured["normal_flow"]["J2"] == pytest.approx(4.0)
    assert captured["normal_flow"]["P1"] == pytest.approx(5.0)

    assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls)
    assert len(scheme_calls) == 3
    assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls)
    assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls)
    assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls)

    assert len(realtime_calls) == 3
    assert all(datetime.fromisoformat(call["start_time"]).hour == 8 for call in realtime_calls)
    assert all(datetime.fromisoformat(call["end_time"]).hour == 9 for call in realtime_calls)
    assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls)
    assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls)
    assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls)

    assert result["scada_window"] == {
        "burst_start": "2025-01-01T08:00:00",
        "burst_end": "2025-01-01T09:00:00",
    }


def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
    module = _load_burst_location_module(monkeypatch)
    _patch_network_setup(
        monkeypatch,
        module,
        tmp_path,
        [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            }
        ],
    )
    monkeypatch.setattr(module, "run_burst_location", lambda **kwargs: {})

    with pytest.raises(ValueError, match="simulation_scheme_name"):
        module.run_burst_location_by_network(
            network="tjwater",
            username="testuser",
            data_source="simulation",
            burst_leakage=1.0,
            scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
            scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
        )


def test_run_burst_location_monitoring_uses_scada_for_burst_and_realtime_for_normal(
    monkeypatch, tmp_path
):
    module = _load_burst_location_module(monkeypatch)
    captured = {}
    scada_calls = []
    realtime_calls = []

    _patch_network_setup(
        monkeypatch,
        module,
        tmp_path,
        [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            }
        ],
    )

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {"located_pipe": "Pipe-001"}

    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)

    def fake_scada_query(**kwargs):
        # SCADA readings are keyed by api_query_id, not by element id.
        scada_calls.append(kwargs)
        return {
            "pressure-query": [
                {"time": kwargs["start_time"], "value": 20.0},
                {"time": kwargs["end_time"], "value": 22.0},
            ]
        }

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        return {
            "J1": [
                {"time": kwargs["start_time"], "value": 10.0},
                {"time": kwargs["end_time"], "value": 12.0},
            ]
        }

    monkeypatch.setattr(
        module.InternalQueries,
        "query_scada_by_ids_timerange",
        staticmethod(fake_scada_query),
    )
    monkeypatch.setattr(
        module.InternalQueries,
        "query_realtime_simulation_by_ids_timerange",
        staticmethod(fake_realtime_query),
    )

    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="monitoring",
        burst_leakage=1.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
    )

    assert result["observed_source"] == "scada_burst_realtime_normal_timerange"
    assert len(scada_calls) == 1
    assert len(realtime_calls) == 1
    assert captured["burst_pressure"]["J1"] == pytest.approx(21.0)
    assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)