Files
TJWaterServerBinary/tests/unit/test_burst_location_service.py
T
2026-04-14 14:46:51 +08:00

320 lines
12 KiB
Python

import importlib.util
import sys
import types
from datetime import datetime, timedelta, timezone
from pathlib import Path
import pytest
def _ensure_stub_package(name: str) -> types.ModuleType:
    """Return ``sys.modules[name]``, registering an empty package stub if it is missing.

    Entries that are already present in ``sys.modules`` are left untouched.
    """
    module = sys.modules.get(name)
    if module is None:
        module = types.ModuleType(name)
        # A ``__path__`` attribute is what makes the import system treat a
        # module as a package, so dotted submodule names can resolve under it.
        module.__path__ = []
        sys.modules[name] = module
    return module


def _install_stub_module(name: str, **attributes) -> types.ModuleType:
    """Register a fresh module called *name* in ``sys.modules`` with *attributes* set on it.

    Any module already registered under *name* is replaced.
    """
    module = types.ModuleType(name)
    for attr_name, value in attributes.items():
        setattr(module, attr_name, value)
    sys.modules[name] = module
    return module


def _load_burst_location_module(module_path=None):
    """Import the burst-location service in isolation, with its dependencies stubbed.

    Empty package stubs are registered for any ``app.*`` package missing from
    ``sys.modules``. Fresh stub modules replace ``app.services.time_api``,
    ``app.algorithms.burst_location``, the TimescaleDB ``internal_queries``,
    ``app.services.scheme_management`` and ``app.services.tjnetwork``.
    Tests then monkeypatch the attributes they care about on the returned
    module. These ``sys.modules`` entries are not restored afterwards.

    Args:
        module_path: File to load. Defaults to ``app/services/burst_location.py``
            under ``Path(__file__).resolve().parents[2]``.

    Returns:
        The freshly executed module, registered as
        ``sys.modules["tests_burst_location_under_test"]``.

    Raises:
        ImportError: If no import spec/loader can be built for ``module_path``.
            Any exception raised while executing the module propagates unchanged.
    """
    if module_path is None:
        module_path = (
            Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py"
        )

    for package_name in (
        "app",
        "app.algorithms",
        "app.infra",
        "app.infra.db",
        "app.infra.db.timescaledb",
        "app.services",
    ):
        _ensure_stub_package(package_name)

    def parse_utc_time(value, field_name="datetime"):
        # Aware datetimes are converted to UTC; anything else is parsed as ISO text.
        # NOTE(review): ``astimezone`` interprets naive values in the host's local
        # timezone -- confirm this matches the real ``time_api`` helper.
        if isinstance(value, datetime) and value.tzinfo is not None:
            return value.astimezone(timezone.utc)
        return datetime.fromisoformat(value).astimezone(timezone.utc)

    def extract_date(value, field_name="date"):
        if isinstance(value, datetime):
            return value.date()
        return datetime.fromisoformat(value).date()

    _install_stub_module(
        "app.services.time_api",
        parse_utc_time=parse_utc_time,
        extract_date=extract_date,
        utc_now=lambda: datetime.now(timezone.utc),
    )
    _install_stub_module(
        "app.algorithms.burst_location",
        run_burst_location=lambda **kwargs: {},
    )

    class DummyInternalQueries:
        # The query methods must exist so tests can monkeypatch them (setattr
        # with the default raising=True fails on missing attributes).
        @staticmethod
        def query_scada_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_scheme_simulation_by_ids_timerange(**kwargs):
            return {}

        @staticmethod
        def query_realtime_simulation_by_ids_timerange(**kwargs):
            return {}

    _install_stub_module(
        "app.infra.db.timescaledb.internal_queries",
        InternalQueries=DummyInternalQueries,
    )
    _install_stub_module(
        "app.services.scheme_management",
        query_burst_location_scheme_detail=lambda *args, **kwargs: {},
        query_burst_location_schemes=lambda *args, **kwargs: [],
        scheme_name_exists=lambda *args, **kwargs: False,
        store_scheme_info=lambda *args, **kwargs: None,
    )
    _install_stub_module(
        "app.services.tjnetwork",
        dump_inp=lambda *args, **kwargs: None,
        get_all_scada_info=lambda *args, **kwargs: [],
    )

    module_name = "tests_burst_location_under_test"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before use: module_from_spec(None) would die with an unhelpful
    # AttributeError instead of this message.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing (importlib's "import a source file directly"
    # recipe) so code that looks itself up via sys.modules[__name__] works.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module registered on failure.
        sys.modules.pop(module_name, None)
        raise
    return module
def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkeypatch, tmp_path):
    """Simulation source: burst side from the named scheme, normal side from realtime.

    The burst window is given as 08:00-09:00 at +08:00. The realtime (normal)
    queries must cover that same window expressed in UTC (00:00-01:00).
    """
    module = _load_burst_location_module()
    captured = {}
    scheme_calls = []
    realtime_calls = []
    cst = timezone(timedelta(hours=8))

    monkeypatch.setattr(
        module,
        "get_all_scada_info",
        lambda network: [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            },
            {
                "type": "pipe_flow",
                "associated_element_id": "P1",
                "api_query_id": "pipe-flow-query",
            },
            {
                "type": "demand",
                "associated_element_id": "J2",
                "api_query_id": "demand-query",
            },
        ],
    )
    monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {
            "located_pipe": "Pipe-001",
            "simulation_times": 3,
            "similarity_mode": "combined",
        }

    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)

    def _build_series(start_time: str, values: list[float]) -> list[dict]:
        # One sample every 15 minutes, starting at ``start_time``.
        base_time = datetime.fromisoformat(start_time)
        return [
            {
                "time": (base_time + timedelta(minutes=15 * index)).isoformat(),
                "value": value,
            }
            for index, value in enumerate(values)
        ]

    def fake_scheme_query(**kwargs):
        scheme_calls.append(kwargs)
        # Serve burst-level values only when the query starts in the 08:00 hour
        # at +08:00 (the burst start hour); otherwise serve normal-level values.
        start_hour = datetime.fromisoformat(kwargs["start_time"]).astimezone(cst).hour
        if kwargs["element_type"] == "node" and kwargs["field"] == "pressure":
            values = [12.0, 14.0, 16.0, 18.0] if start_hour == 8 else [8.0, 10.0, 12.0, 14.0]
            return {"J1": _build_series(kwargs["start_time"], values)}
        if kwargs["element_type"] == "link" and kwargs["field"] == "flow":
            values = [5.0, 7.0, 9.0, 11.0] if start_hour == 8 else [2.0, 4.0, 6.0, 8.0]
            return {"P1": _build_series(kwargs["start_time"], values)}
        if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand":
            values = [3.0, 5.0, 7.0, 9.0] if start_hour == 8 else [1.0, 3.0, 5.0, 7.0]
            return {"J2": _build_series(kwargs["start_time"], values)}
        raise AssertionError(f"Unexpected scheme query: {kwargs}")

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        if kwargs["element_type"] == "node" and kwargs["field"] == "pressure":
            return {"J1": _build_series(kwargs["start_time"], [8.0, 10.0, 12.0, 14.0])}
        if kwargs["element_type"] == "link" and kwargs["field"] == "flow":
            return {"P1": _build_series(kwargs["start_time"], [2.0, 4.0, 6.0, 8.0])}
        if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand":
            return {"J2": _build_series(kwargs["start_time"], [1.0, 3.0, 5.0, 7.0])}
        raise AssertionError(f"Unexpected realtime query: {kwargs}")

    monkeypatch.setattr(
        module.InternalQueries,
        "query_scheme_simulation_by_ids_timerange",
        staticmethod(fake_scheme_query),
    )
    monkeypatch.setattr(
        module.InternalQueries,
        "query_realtime_simulation_by_ids_timerange",
        staticmethod(fake_realtime_query),
    )

    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="simulation",
        simulation_scheme_name="BurstSchemeA",
        simulation_scheme_type="burst_analysis",
        burst_leakage=10.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=cst),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=cst),
        use_scada_flow=True,
    )

    assert result["observed_source"] == "simulation_scheme_burst_realtime_normal_timerange"
    assert result["simulation_scheme"] == {
        "name": "BurstSchemeA",
        "type": "burst_analysis",
    }
    assert result["pressure_samples"] == {"burst": 4, "normal": 4}
    assert result["flow_samples"] == {"burst": 4, "normal": 4}
    assert list(captured["burst_pressure"].index) == ["J1"]
    assert captured["burst_pressure"]["J1"] == pytest.approx(15.0)
    assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)
    assert captured["burst_flow"]["J2"] == pytest.approx(6.0)
    assert captured["burst_flow"]["P1"] == pytest.approx(8.0)
    assert captured["normal_flow"]["J2"] == pytest.approx(4.0)
    assert captured["normal_flow"]["P1"] == pytest.approx(5.0)

    assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls)
    assert len(scheme_calls) == 3
    assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls)
    assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls)
    assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls)

    def _as_utc_wall_clock(value: str) -> datetime:
        # Compare whole instants rather than the bare ``hour`` field, which would
        # also accept a wrong date or a non-UTC offset such as 00:00+08:00.
        # Offset-aware values are normalised to UTC; naive ones are taken as UTC.
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc)
        return parsed.replace(tzinfo=None)

    expected_normal_start = datetime(2025, 1, 1, 0, 0)
    expected_normal_end = datetime(2025, 1, 1, 1, 0)
    assert len(realtime_calls) == 3
    assert all(
        _as_utc_wall_clock(call["start_time"]) == expected_normal_start for call in realtime_calls
    )
    assert all(
        _as_utc_wall_clock(call["end_time"]) == expected_normal_end for call in realtime_calls
    )
    assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls)
    assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls)
    assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls)

    assert result["scada_window"] == {
        "burst_start": "2025-01-01T00:00:00+00:00",
        "burst_end": "2025-01-01T01:00:00+00:00",
    }
def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
    """A simulation-sourced run without ``simulation_scheme_name`` raises ValueError."""
    module = _load_burst_location_module()
    utc_plus_8 = timezone(timedelta(hours=8))
    pressure_sensor = {
        "type": "pressure",
        "associated_element_id": "J1",
        "api_query_id": "pressure-query",
    }
    fake_inp_path = str(tmp_path / "fake.inp")

    # Copy the sensor dict per call so every call gets a fresh list of fresh dicts.
    monkeypatch.setattr(module, "get_all_scada_info", lambda network: [dict(pressure_sensor)])
    monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: fake_inp_path)
    monkeypatch.setattr(module, "run_burst_location", lambda **kwargs: {})

    # Deliberately omits simulation_scheme_name.
    call_kwargs = dict(
        network="tjwater",
        username="testuser",
        data_source="simulation",
        burst_leakage=1.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=utc_plus_8),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=utc_plus_8),
    )
    with pytest.raises(ValueError, match="simulation_scheme_name"):
        module.run_burst_location_by_network(**call_kwargs)
def test_run_burst_location_monitoring_uses_scada_for_burst_and_realtime_for_normal(
    monkeypatch, tmp_path
):
    """Monitoring source: SCADA readings feed the burst side, realtime simulation the normal side."""
    module = _load_burst_location_module()
    captured = {}
    scada_calls = []
    realtime_calls = []
    utc_plus_8 = timezone(timedelta(hours=8))

    def fake_scada_info(network):
        return [
            {
                "type": "pressure",
                "associated_element_id": "J1",
                "api_query_id": "pressure-query",
            }
        ]

    def fake_run_burst_location(**kwargs):
        captured.update(kwargs)
        return {"located_pipe": "Pipe-001"}

    def window_endpoints(kwargs, first_value, last_value):
        # Two readings: one stamped at the requested start, one at the requested end.
        return [
            {"time": kwargs["start_time"], "value": first_value},
            {"time": kwargs["end_time"], "value": last_value},
        ]

    def fake_scada_query(**kwargs):
        scada_calls.append(kwargs)
        return {"pressure-query": window_endpoints(kwargs, 20.0, 22.0)}

    def fake_realtime_query(**kwargs):
        realtime_calls.append(kwargs)
        return {"J1": window_endpoints(kwargs, 10.0, 12.0)}

    monkeypatch.setattr(module, "get_all_scada_info", fake_scada_info)
    monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
    monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)
    monkeypatch.setattr(
        module.InternalQueries,
        "query_scada_by_ids_timerange",
        staticmethod(fake_scada_query),
    )
    monkeypatch.setattr(
        module.InternalQueries,
        "query_realtime_simulation_by_ids_timerange",
        staticmethod(fake_realtime_query),
    )

    result = module.run_burst_location_by_network(
        network="tjwater",
        username="testuser",
        data_source="monitoring",
        burst_leakage=1.0,
        scada_burst_start=datetime(2025, 1, 1, 8, 0, 0, tzinfo=utc_plus_8),
        scada_burst_end=datetime(2025, 1, 1, 9, 0, 0, tzinfo=utc_plus_8),
    )

    assert result["observed_source"] == "scada_burst_realtime_normal_timerange"
    assert (len(scada_calls), len(realtime_calls)) == (1, 1)
    # Each side averages its two readings: (20 + 22) / 2 and (10 + 12) / 2.
    assert captured["burst_pressure"]["J1"] == pytest.approx(21.0)
    assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)