320 lines
12 KiB
Python
320 lines
12 KiB
Python
import importlib.util
|
|
import sys
|
|
import types
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
|
|
def _load_burst_location_module():
    """Load ``app/services/burst_location.py`` with its project dependencies stubbed out.

    Minimal stand-ins are written into ``sys.modules`` for the parent packages
    (``app``, ``app.algorithms``, ``app.infra``, ``app.infra.db``,
    ``app.infra.db.timescaledb``, ``app.services``) and for the modules
    ``app.services.time_api``, ``app.algorithms.burst_location``,
    ``app.infra.db.timescaledb.internal_queries``,
    ``app.services.scheme_management`` and ``app.services.tjnetwork``.
    These are presumably the project imports of the service; verify against
    ``burst_location.py``.

    The service file is then executed under the private name
    ``tests_burst_location_under_test``. Each call re-executes the file, so
    every test gets a fresh module object to monkeypatch.

    NOTE(review): the stubs are written directly into ``sys.modules`` and are
    never removed. They remain visible to any test that runs later in the same
    process.

    Returns:
        The freshly executed service module.

    Raises:
        ImportError: if no import spec/loader can be built for the service file.
    """
    module_path = (
        Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py"
    )

    def ensure_package(name: str) -> types.ModuleType:
        """Return package ``name`` from ``sys.modules``, registering an empty one if absent."""
        module = sys.modules.get(name)
        if module is None:
            module = types.ModuleType(name)
            module.__path__ = []  # a __path__ attribute marks the module as a package
            sys.modules[name] = module
        return module

    def install_stub(name: str, **attributes) -> None:
        """Register (or overwrite) ``sys.modules[name]`` with a module exposing ``attributes``."""
        stub = types.ModuleType(name)
        vars(stub).update(attributes)
        sys.modules[name] = stub

    # Parent packages must exist so the dotted stub module names below resolve.
    for package_name in (
        "app",
        "app.algorithms",
        "app.infra",
        "app.infra.db",
        "app.infra.db.timescaledb",
        "app.services",
    ):
        ensure_package(package_name)

    def parse_utc_time(value, field_name="datetime"):
        # Accepts a datetime or an ISO-8601 string and converts it to UTC.
        # NOTE(review): naive values are interpreted as *local* time by
        # astimezone(); the tests here only pass tz-aware inputs. Confirm this
        # matches the real time_api.parse_utc_time if that changes.
        parsed = value if isinstance(value, datetime) else datetime.fromisoformat(value)
        return parsed.astimezone(timezone.utc)

    def extract_date(value, field_name="date"):
        # Calendar date of a datetime or of an ISO-8601 string.
        parsed = value if isinstance(value, datetime) else datetime.fromisoformat(value)
        return parsed.date()

    install_stub(
        "app.services.time_api",
        parse_utc_time=parse_utc_time,
        extract_date=extract_date,
        utc_now=lambda: datetime.now(timezone.utc),
    )
    install_stub(
        "app.algorithms.burst_location",
        run_burst_location=lambda **kwargs: {},
    )

    class DummyInternalQueries:
        """Query facade whose methods all return ``{}``; tests monkeypatch the ones they need."""

        @staticmethod
        def query_scada_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_scheme_simulation_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_realtime_simulation_by_ids_timerange(**kwargs):
            return {}

    install_stub(
        "app.infra.db.timescaledb.internal_queries",
        InternalQueries=DummyInternalQueries,
    )
    install_stub(
        "app.services.scheme_management",
        query_burst_location_scheme_detail=lambda *args, **kwargs: {},
        query_burst_location_schemes=lambda *args, **kwargs: [],
        scheme_name_exists=lambda *args, **kwargs: False,
        store_scheme_info=lambda *args, **kwargs: None,
    )
    install_stub(
        "app.services.tjnetwork",
        dump_inp=lambda *args, **kwargs: None,
        get_all_scada_info=lambda *args, **kwargs: [],
    )

    module_name = "tests_burst_location_under_test"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before module_from_spec(). With spec=None that call fails with an
    # opaque AttributeError, and a bare assert would vanish under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file
    # directly" recipe. Code that looks its own module up in sys.modules while
    # executing (e.g. dataclasses) then finds it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
|
|
|
|
def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkeypatch, tmp_path):
    """Simulation source: burst series from the named scheme, normal series from realtime.

    The scheme fake returns the elevated "burst" values when the requested
    window starts at 08:00 in UTC+8, and the "normal" values otherwise. The
    realtime fake always returns the "normal" values. The test checks the
    per-element means passed to ``run_burst_location`` and the windows queried
    from each source.
    """
    module = _load_burst_location_module()
    utc_plus_8 = timezone(timedelta(hours=8))

    captured = {}
    scheme_calls = []
    realtime_calls = []

    # (element_type, field) -> (element id, burst-window values, normal-window values)
    series_by_query = {
        ("node", "pressure"): ("J1", [12.0, 14.0, 16.0, 18.0], [8.0, 10.0, 12.0, 14.0]),
        ("link", "flow"): ("P1", [5.0, 7.0, 9.0, 11.0], [2.0, 4.0, 6.0, 8.0]),
        ("node", "actual_demand"): ("J2", [3.0, 5.0, 7.0, 9.0], [1.0, 3.0, 5.0, 7.0]),
    }
    sensor_rows = (
        {"type": "pressure", "associated_element_id": "J1", "api_query_id": "pressure-query"},
        {"type": "pipe_flow", "associated_element_id": "P1", "api_query_id": "pipe-flow-query"},
        {"type": "demand", "associated_element_id": "J2", "api_query_id": "demand-query"},
    )

    def build_series(start_time: str, values: list[float]) -> list[dict]:
        """Samples spaced 15 minutes apart, starting at the ISO timestamp ``start_time``."""
        origin = datetime.fromisoformat(start_time)
        step = timedelta(minutes=15)
        return [
            {"time": (origin + offset * step).isoformat(), "value": sample}
            for offset, sample in enumerate(values)
        ]

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {
            "located_pipe": "Pipe-001",
            "simulation_times": 3,
            "similarity_mode": "combined",
        }

    def fake_scheme_query(**kwargs):
        scheme_calls.append(kwargs)
        start_hour_utc8 = (
            datetime.fromisoformat(kwargs["start_time"]).astimezone(utc_plus_8).hour
        )
        entry = series_by_query.get((kwargs["element_type"], kwargs["field"]))
        if entry is None:
            raise AssertionError(f"Unexpected scheme query: {kwargs}")
        element_id, burst_values, normal_values = entry
        chosen = burst_values if start_hour_utc8 == 8 else normal_values
        return {element_id: build_series(kwargs["start_time"], chosen)}

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        entry = series_by_query.get((kwargs["element_type"], kwargs["field"]))
        if entry is None:
            raise AssertionError(f"Unexpected realtime query: {kwargs}")
        element_id, _burst_values, normal_values = entry
        return {element_id: build_series(kwargs["start_time"], normal_values)}

    # Build fresh dicts on every call, like the original literal did.
    monkeypatch.setattr(
        module, "get_all_scada_info", lambda network: [dict(row) for row in sensor_rows]
    )
    monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)
    for query_name, fake in (
        ("query_scheme_simulation_by_ids_timerange", fake_scheme_query),
        ("query_realtime_simulation_by_ids_timerange", fake_realtime_query),
    ):
        monkeypatch.setattr(module.InternalQueries, query_name, staticmethod(fake))

    burst_start = datetime(2025, 1, 1, 8, 0, 0, tzinfo=utc_plus_8)
    burst_end = datetime(2025, 1, 1, 9, 0, 0, tzinfo=utc_plus_8)
    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="simulation",
        simulation_scheme_name="BurstSchemeA",
        simulation_scheme_type="burst_analysis",
        burst_leakage=10.0,
        scada_burst_start=burst_start,
        scada_burst_end=burst_end,
        use_scada_flow=True,
    )

    assert result["observed_source"] == "simulation_scheme_burst_realtime_normal_timerange"
    assert result["simulation_scheme"] == {"name": "BurstSchemeA", "type": "burst_analysis"}
    assert result["pressure_samples"] == {"burst": 4, "normal": 4}
    assert result["flow_samples"] == {"burst": 4, "normal": 4}

    # Window means: demand (J2) and pipe flow (P1) both land in the flow vectors.
    assert list(captured["burst_pressure"].index) == ["J1"]
    expected_means = {
        ("burst_pressure", "J1"): 15.0,
        ("normal_pressure", "J1"): 11.0,
        ("burst_flow", "J2"): 6.0,
        ("burst_flow", "P1"): 8.0,
        ("normal_flow", "J2"): 4.0,
        ("normal_flow", "P1"): 5.0,
    }
    for (vector_name, element_id), expected in expected_means.items():
        assert captured[vector_name][element_id] == pytest.approx(expected)

    expected_queries = set(series_by_query)

    assert len(scheme_calls) == 3
    assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls)
    assert expected_queries <= {(call["element_type"], call["field"]) for call in scheme_calls}

    assert len(realtime_calls) == 3
    assert all(datetime.fromisoformat(call["start_time"]).hour == 0 for call in realtime_calls)
    assert all(datetime.fromisoformat(call["end_time"]).hour == 1 for call in realtime_calls)
    assert expected_queries <= {(call["element_type"], call["field"]) for call in realtime_calls}

    assert result["scada_window"] == {
        "burst_start": "2025-01-01T00:00:00+00:00",
        "burst_end": "2025-01-01T01:00:00+00:00",
    }
|
|
|
|
|
|
def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
    """A simulation-sourced run without ``simulation_scheme_name`` must raise ValueError."""
    module = _load_burst_location_module()

    pressure_sensor = {
        "type": "pressure",
        "associated_element_id": "J1",
        "api_query_id": "pressure-query",
    }
    replacements = {
        # A fresh list/dict per call, matching the original inline literal.
        "get_all_scada_info": lambda network: [dict(pressure_sensor)],
        "_prepare_burst_inp": lambda network: str(tmp_path / "fake.inp"),
        "run_burst_location": lambda **kwargs: {},
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(module, attr_name, replacement)

    utc_plus_8 = timezone(timedelta(hours=8))
    call_kwargs = {
        "network": "tjwater",
        "username": "testuser",
        "data_source": "simulation",
        "burst_leakage": 1.0,
        "scada_burst_start": datetime(2025, 1, 1, 8, 0, 0, tzinfo=utc_plus_8),
        "scada_burst_end": datetime(2025, 1, 1, 9, 0, 0, tzinfo=utc_plus_8),
    }

    with pytest.raises(ValueError, match="simulation_scheme_name"):
        module.run_burst_location_by_network(**call_kwargs)
|
|
|
|
|
|
def test_run_burst_location_monitoring_uses_scada_for_burst_and_realtime_for_normal(
    monkeypatch, tmp_path
):
    """Monitoring source: SCADA readings feed the burst window, realtime simulation the normal one.

    The SCADA fake returns 20/22 for the pressure sensor and the realtime fake
    returns 10/12 for its junction. The means 21 and 11 must therefore reach
    ``run_burst_location`` as burst and normal pressure respectively.
    """
    module = _load_burst_location_module()
    captured = {}
    scada_calls = []
    realtime_calls = []

    def two_point_series(kwargs, first_value, last_value):
        """One sample at the requested window start and one at its end."""
        return [
            {"time": kwargs["start_time"], "value": first_value},
            {"time": kwargs["end_time"], "value": last_value},
        ]

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {"located_pipe": "Pipe-001"}

    def fake_scada_query(**kwargs):
        scada_calls.append(kwargs)
        return {"pressure-query": two_point_series(kwargs, 20.0, 22.0)}

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        return {"J1": two_point_series(kwargs, 10.0, 12.0)}

    monkeypatch.setattr(
        module,
        "get_all_scada_info",
        lambda network: [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            }
        ],
    )
    monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)
    for query_name, fake in (
        ("query_scada_by_ids_timerange", fake_scada_query),
        ("query_realtime_simulation_by_ids_timerange", fake_realtime_query),
    ):
        monkeypatch.setattr(module.InternalQueries, query_name, staticmethod(fake))

    utc_plus_8 = timezone(timedelta(hours=8))
    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="monitoring",
        burst_leakage=1.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=utc_plus_8),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=utc_plus_8),
    )

    assert result["observed_source"] == "scada_burst_realtime_normal_timerange"
    assert len(scada_calls) == 1
    assert len(realtime_calls) == 1
    assert captured["burst_pressure"]["J1"] == pytest.approx(21.0)
    assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)
|