新增模拟数据源支持,重构爆管定位逻辑

This commit is contained in:
2026-03-07 10:50:25 +08:00
parent bc74e94fbb
commit 7f481ca261
4 changed files with 682 additions and 43 deletions
+5
View File
@@ -1,6 +1,8 @@
from typing import Any from typing import Any
from datetime import datetime from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
@@ -16,6 +18,7 @@ router = APIRouter()
class BurstLocationRequest(BaseModel): class BurstLocationRequest(BaseModel):
network: str network: str
data_source: Literal["monitoring", "simulation"] = "monitoring"
pressure_scada_ids: list[str] | None = None pressure_scada_ids: list[str] | None = None
burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None
normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None
@@ -31,6 +34,8 @@ class BurstLocationRequest(BaseModel):
scada_normal_end: datetime | None = None scada_normal_end: datetime | None = None
use_scada_flow: bool = False use_scada_flow: bool = False
scheme_name: str | None = None scheme_name: str | None = None
simulation_scheme_name: str | None = None
simulation_scheme_type: str | None = None
@router.post("/locate/") @router.post("/locate/")
@@ -3,6 +3,8 @@ from typing import List
from fastapi.logger import logger from fastapi.logger import logger
from datetime import datetime, timedelta from datetime import datetime, timedelta
import psycopg import psycopg
from psycopg import sql
from psycopg.rows import dict_row
import time import time
from app.infra.db.timescaledb.schemas.scheme import SchemeRepository from app.infra.db.timescaledb.schemas.scheme import SchemeRepository
from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository
@@ -170,3 +172,159 @@ class InternalQueries:
time.sleep(1) time.sleep(1)
else: else:
raise raise
@staticmethod
def query_realtime_simulation_by_ids_timerange(
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
db_name: str = None,
max_retries: int = 3,
) -> dict[str, list[dict]]:
"""查询实时模拟结果,返回 {id: [{time, value}, ...]}。"""
return InternalQueries._query_simulation_by_ids_timerange(
schema_name="realtime",
element_ids=element_ids,
start_time=start_time,
end_time=end_time,
element_type=element_type,
field=field,
db_name=db_name,
max_retries=max_retries,
)
@staticmethod
def query_scheme_simulation_by_ids_timerange(
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
scheme_type: str,
scheme_name: str,
db_name: str = None,
max_retries: int = 3,
) -> dict[str, list[dict]]:
"""查询方案模拟结果,返回 {id: [{time, value}, ...]}。"""
return InternalQueries._query_simulation_by_ids_timerange(
schema_name="scheme",
element_ids=element_ids,
start_time=start_time,
end_time=end_time,
element_type=element_type,
field=field,
db_name=db_name,
max_retries=max_retries,
scheme_type=scheme_type,
scheme_name=scheme_name,
)
@staticmethod
def _query_simulation_by_ids_timerange(
*,
schema_name: str,
element_ids: List[str],
start_time: str | datetime,
end_time: str | datetime,
element_type: str,
field: str,
db_name: str = None,
max_retries: int = 3,
scheme_type: str | None = None,
scheme_name: str | None = None,
) -> dict[str, list[dict]]:
if not element_ids:
return {}
start_dt = (
datetime.fromisoformat(start_time)
if isinstance(start_time, str)
else start_time
)
end_dt = (
datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time
)
table_name, valid_fields = InternalQueries._resolve_simulation_table(element_type)
if field not in valid_fields:
raise ValueError(f"Invalid field for {element_type}: {field}")
if schema_name not in {"realtime", "scheme"}:
raise ValueError(f"Unsupported schema_name: {schema_name}")
if schema_name == "scheme" and (not scheme_type or not scheme_name):
raise ValueError("scheme 查询必须提供 scheme_type 和 scheme_name。")
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
with conn.cursor(row_factory=dict_row) as cur:
if schema_name == "scheme":
query = sql.SQL(
"SELECT id, time, {} FROM {}.{} "
"WHERE scheme_type = %s AND scheme_name = %s "
"AND time >= %s AND time <= %s AND id = ANY(%s)"
).format(
sql.Identifier(field),
sql.Identifier(schema_name),
sql.Identifier(table_name),
)
cur.execute(
query,
(
scheme_type,
scheme_name,
start_dt,
end_dt,
element_ids,
),
)
else:
query = sql.SQL(
"SELECT id, time, {} FROM {}.{} "
"WHERE time >= %s AND time <= %s AND id = ANY(%s)"
).format(
sql.Identifier(field),
sql.Identifier(schema_name),
sql.Identifier(table_name),
)
cur.execute(query, (start_dt, end_dt, element_ids))
rows = cur.fetchall()
result: dict[str, list[dict]] = {
element_id: [] for element_id in element_ids
}
for row in rows:
result.setdefault(row["id"], []).append(
{"time": row["time"].isoformat(), "value": row[field]}
)
for element_id in result:
result[element_id].sort(key=lambda item: item["time"])
return result
except Exception as e:
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1)
else:
raise
@staticmethod
def _resolve_simulation_table(element_type: str) -> tuple[str, set[str]]:
normalized_type = element_type.lower()
if normalized_type == "node":
return "node_simulation", {"actual_demand", "total_head", "pressure", "quality"}
if normalized_type == "link":
return "link_simulation", {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
raise ValueError(f"Unsupported element_type: {element_type}")
+309 -43
View File
@@ -18,6 +18,8 @@ from app.services.tjnetwork import dump_inp, get_all_scada_info
SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]] SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]]
FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"} FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"}
SIMULATION_DATA_SOURCES = {"monitoring", "simulation"}
DEFAULT_SIMULATION_SCHEME_TYPE = "burst_analysis"
def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series: def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
@@ -44,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series:
def run_burst_location_by_network( def run_burst_location_by_network(
*, *,
network: str, network: str,
data_source: str = "monitoring",
burst_leakage: float, burst_leakage: float,
pressure_scada_ids: list[str] | None = None, pressure_scada_ids: list[str] | None = None,
burst_pressure: SeriesInput | None = None, burst_pressure: SeriesInput | None = None,
@@ -59,10 +62,18 @@ def run_burst_location_by_network(
scada_normal_end: datetime | str | None = None, scada_normal_end: datetime | str | None = None,
use_scada_flow: bool = False, use_scada_flow: bool = False,
scheme_name: str | None = None, scheme_name: str | None = None,
simulation_scheme_name: str | None = None,
simulation_scheme_type: str | None = None,
username: str = "admin", username: str = "admin",
) -> dict[str, Any]: ) -> dict[str, Any]:
if not network: if not network:
raise ValueError("network is required.") raise ValueError("network is required.")
normalized_data_source = _normalize_data_source(
data_source, simulation_scheme_name=simulation_scheme_name
)
resolved_simulation_scheme_type = (
simulation_scheme_type or DEFAULT_SIMULATION_SCHEME_TYPE
)
selected_pressure_ids = ( selected_pressure_ids = (
_dedupe_ids(pressure_scada_ids) _dedupe_ids(pressure_scada_ids)
@@ -93,26 +104,60 @@ def run_burst_location_by_network(
scada_normal_start=scada_normal_start, scada_normal_start=scada_normal_start,
scada_normal_end=scada_normal_end, scada_normal_end=scada_normal_end,
) )
burst_pressure_series, burst_pressure_samples = _build_observed_series_from_scada( if normalized_data_source == "simulation":
network=network, if not simulation_scheme_name:
sensor_ids=selected_pressure_ids, raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
start_dt=burst_start_dt, burst_pressure_series, burst_pressure_samples = (
end_dt=burst_end_dt, _build_observed_series_from_simulation(
data_type="pressure", network=network,
series_name="burst_pressure", sensor_ids=selected_pressure_ids,
) start_dt=burst_start_dt,
( end_dt=burst_end_dt,
normal_pressure_series, data_type="pressure",
normal_pressure_samples, series_name="burst_pressure",
) = _build_observed_series_from_scada( simulation_source="scheme",
network=network, simulation_scheme_name=simulation_scheme_name,
sensor_ids=selected_pressure_ids, simulation_scheme_type=resolved_simulation_scheme_type,
start_dt=normal_start_dt, )
end_dt=normal_end_dt, )
data_type="pressure", (
series_name="normal_pressure", normal_pressure_series,
) normal_pressure_samples,
observed_source = "backend_timerange" ) = _build_observed_series_from_simulation(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="pressure",
series_name="normal_pressure",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
observed_source = "simulation_scheme_timerange"
else:
burst_pressure_series, burst_pressure_samples = (
_build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="pressure",
series_name="burst_pressure",
)
)
(
normal_pressure_series,
normal_pressure_samples,
) = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_pressure_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="pressure",
series_name="normal_pressure",
)
observed_source = "backend_timerange"
else: else:
if burst_pressure is None or normal_pressure is None: if burst_pressure is None or normal_pressure is None:
raise ValueError( raise ValueError(
@@ -139,22 +184,52 @@ def run_burst_location_by_network(
) )
if not selected_flow_ids: if not selected_flow_ids:
raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。") raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。")
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada( if normalized_data_source == "simulation":
network=network, if not simulation_scheme_name:
sensor_ids=selected_flow_ids, raise ValueError("模拟方案模式必须提供 simulation_scheme_name。")
start_dt=burst_start_dt, burst_flow_series, burst_flow_samples = (
end_dt=burst_end_dt, _build_observed_series_from_simulation(
data_type="flow", network=network,
series_name="burst_flow", sensor_ids=selected_flow_ids,
) start_dt=burst_start_dt,
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada( end_dt=burst_end_dt,
network=network, data_type="flow",
sensor_ids=selected_flow_ids, series_name="burst_flow",
start_dt=normal_start_dt, simulation_source="scheme",
end_dt=normal_end_dt, simulation_scheme_name=simulation_scheme_name,
data_type="flow", simulation_scheme_type=resolved_simulation_scheme_type,
series_name="normal_flow", )
) )
normal_flow_series, normal_flow_samples = (
_build_observed_series_from_simulation(
network=network,
sensor_ids=selected_flow_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="flow",
series_name="normal_flow",
simulation_source="scheme",
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=resolved_simulation_scheme_type,
)
)
else:
burst_flow_series, burst_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=burst_start_dt,
end_dt=burst_end_dt,
data_type="flow",
series_name="burst_flow",
)
normal_flow_series, normal_flow_samples = _build_observed_series_from_scada(
network=network,
sensor_ids=selected_flow_ids,
start_dt=normal_start_dt,
end_dt=normal_end_dt,
data_type="flow",
series_name="normal_flow",
)
else: else:
if flow_scada_ids is not None: if flow_scada_ids is not None:
selected_flow_ids = _dedupe_ids(flow_scada_ids) selected_flow_ids = _dedupe_ids(flow_scada_ids)
@@ -186,6 +261,7 @@ def run_burst_location_by_network(
payload: dict[str, Any] = { payload: dict[str, Any] = {
**result, **result,
"network": network, "network": network,
"data_source": normalized_data_source,
"pressure_scada_ids": selected_pressure_ids, "pressure_scada_ids": selected_pressure_ids,
"flow_scada_ids": selected_flow_ids or [], "flow_scada_ids": selected_flow_ids or [],
"observed_source": observed_source, "observed_source": observed_source,
@@ -202,6 +278,11 @@ def run_burst_location_by_network(
"normal_start": normal_start_dt.isoformat(), "normal_start": normal_start_dt.isoformat(),
"normal_end": normal_end_dt.isoformat(), "normal_end": normal_end_dt.isoformat(),
} }
if normalized_data_source == "simulation":
payload["simulation_scheme"] = {
"name": simulation_scheme_name,
"type": resolved_simulation_scheme_type,
}
if scheme_name: if scheme_name:
_store_burst_scheme( _store_burst_scheme(
network=network, network=network,
@@ -335,8 +416,174 @@ def _build_observed_series_from_scada(
return pd.Series(values, dtype=float), min(sample_counts) return pd.Series(values, dtype=float), min(sample_counts)
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: def _build_observed_series_from_simulation(
mapping: dict[str, str] = {} *,
network: str,
sensor_ids: list[str],
start_dt: datetime,
end_dt: datetime,
data_type: str,
series_name: str,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> tuple[pd.Series, int]:
sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type)
missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata]
if missing_ids:
preview = ", ".join(missing_ids[:10])
raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}")
simulation_data = _query_simulation_data_by_sensor_ids(
network=network,
sensor_ids=sensor_ids,
sensor_metadata=sensor_metadata,
start_dt=start_dt,
end_dt=end_dt,
data_type=data_type,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
values: dict[str, float] = {}
sample_counts: list[int] = []
for sensor_id in sensor_ids:
records = simulation_data.get(sensor_id, [])
numeric_values = [
float(item["value"])
for item in records
if item.get("value") is not None
]
if not numeric_values:
raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}")
values[sensor_id] = float(sum(numeric_values) / len(numeric_values))
sample_counts.append(len(numeric_values))
return pd.Series(values, dtype=float), min(sample_counts)
def _query_simulation_data_by_sensor_ids(
*,
network: str,
sensor_ids: list[str],
sensor_metadata: dict[str, dict[str, str]],
start_dt: datetime,
end_dt: datetime,
data_type: str,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> dict[str, list[dict[str, Any]]]:
if simulation_source not in {"scheme", "realtime"}:
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids}
if data_type == "pressure":
result.update(
_query_simulation_values(
network=network,
element_ids=sensor_ids,
element_type="node",
field="pressure",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
return result
if data_type != "flow":
raise ValueError(f"Unsupported data_type: {data_type}")
link_ids: list[str] = []
demand_ids: list[str] = []
unsupported_ids: list[str] = []
for sensor_id in sensor_ids:
scada_type = sensor_metadata[sensor_id]["scada_type"]
if scada_type in {"pipe_flow", "flow"}:
link_ids.append(sensor_id)
elif scada_type == "demand":
demand_ids.append(sensor_id)
else:
unsupported_ids.append(f"{sensor_id}({scada_type})")
if unsupported_ids:
preview = ", ".join(unsupported_ids[:10])
raise ValueError(f"flow 模拟数据暂不支持以下 SCADA 类型: {preview}")
if link_ids:
result.update(
_query_simulation_values(
network=network,
element_ids=link_ids,
element_type="link",
field="flow",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
if demand_ids:
result.update(
_query_simulation_values(
network=network,
element_ids=demand_ids,
element_type="node",
field="actual_demand",
start_dt=start_dt,
end_dt=end_dt,
simulation_source=simulation_source,
simulation_scheme_name=simulation_scheme_name,
simulation_scheme_type=simulation_scheme_type,
)
)
return result
def _query_simulation_values(
*,
network: str,
element_ids: list[str],
element_type: str,
field: str,
start_dt: datetime,
end_dt: datetime,
simulation_source: str,
simulation_scheme_name: str | None,
simulation_scheme_type: str,
) -> dict[str, list[dict[str, Any]]]:
if not element_ids:
return {}
if simulation_source == "scheme":
if not simulation_scheme_name:
raise ValueError("读取方案模拟数据时必须提供 simulation_scheme_name。")
return InternalQueries.query_scheme_simulation_by_ids_timerange(
db_name=network,
scheme_type=simulation_scheme_type,
scheme_name=simulation_scheme_name,
element_ids=element_ids,
start_time=start_dt.isoformat(),
end_time=end_dt.isoformat(),
element_type=element_type,
field=field,
)
if simulation_source == "realtime":
return InternalQueries.query_realtime_simulation_by_ids_timerange(
db_name=network,
element_ids=element_ids,
start_time=start_dt.isoformat(),
end_time=end_dt.isoformat(),
element_type=element_type,
field=field,
)
raise ValueError(f"Unsupported simulation_source: {simulation_source}")
def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str, str]]:
metadata: dict[str, dict[str, str]] = {}
for item in get_all_scada_info(network): for item in get_all_scada_info(network):
scada_type = str(item.get("type", "")).lower() scada_type = str(item.get("type", "")).lower()
if data_type == "pressure": if data_type == "pressure":
@@ -347,16 +594,35 @@ def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
continue continue
else: else:
raise ValueError(f"Unsupported data_type: {data_type}") raise ValueError(f"Unsupported data_type: {data_type}")
node_id = item.get("associated_element_id") element_id = item.get("associated_element_id")
query_id = item.get("api_query_id") query_id = item.get("api_query_id")
if ( if (
isinstance(node_id, str) isinstance(element_id, str)
and node_id and element_id
and isinstance(query_id, str) and isinstance(query_id, str)
and query_id and query_id
): ):
mapping[node_id] = query_id metadata[element_id] = {"query_id": query_id, "scada_type": scada_type}
return mapping return metadata
def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]:
metadata = _build_sensor_metadata(network=network, data_type=data_type)
return {
element_id: item["query_id"] for element_id, item in metadata.items()
}
def _normalize_data_source(
data_source: str | None, simulation_scheme_name: str | None = None
) -> str:
normalized = str(data_source or "").strip().lower()
if not normalized:
return "simulation" if simulation_scheme_name else "monitoring"
if normalized not in SIMULATION_DATA_SOURCES:
allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES))
raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}")
return normalized
def _get_sensor_nodes(network: str, data_type: str) -> list[str]: def _get_sensor_nodes(network: str, data_type: str) -> list[str]:
+210
View File
@@ -0,0 +1,210 @@
import importlib.util
import sys
import types
from datetime import datetime, timedelta
from pathlib import Path
import pytest
def _load_burst_location_module():
module_path = (
Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py"
)
def ensure_package(name: str) -> types.ModuleType:
module = sys.modules.get(name)
if module is None:
module = types.ModuleType(name)
module.__path__ = []
sys.modules[name] = module
return module
for package_name in [
"app",
"app.algorithms",
"app.infra",
"app.infra.db",
"app.infra.db.timescaledb",
"app.services",
]:
ensure_package(package_name)
algorithms_module = types.ModuleType("app.algorithms.burst_location")
algorithms_module.run_burst_location = lambda **kwargs: {}
sys.modules["app.algorithms.burst_location"] = algorithms_module
internal_queries_module = types.ModuleType(
"app.infra.db.timescaledb.internal_queries"
)
class DummyInternalQueries:
@staticmethod
def query_scada_by_ids_timerange(**kwargs):
return {}
@staticmethod
def query_scheme_simulation_by_ids_timerange(**kwargs):
return {}
@staticmethod
def query_realtime_simulation_by_ids_timerange(**kwargs):
return {}
internal_queries_module.InternalQueries = DummyInternalQueries
sys.modules["app.infra.db.timescaledb.internal_queries"] = internal_queries_module
scheme_management_module = types.ModuleType("app.services.scheme_management")
scheme_management_module.query_burst_location_scheme_detail = lambda *args, **kwargs: {}
scheme_management_module.query_burst_location_schemes = lambda *args, **kwargs: []
scheme_management_module.scheme_name_exists = lambda *args, **kwargs: False
scheme_management_module.store_scheme_info = lambda *args, **kwargs: None
sys.modules["app.services.scheme_management"] = scheme_management_module
tjnetwork_module = types.ModuleType("app.services.tjnetwork")
tjnetwork_module.dump_inp = lambda *args, **kwargs: None
tjnetwork_module.get_all_scada_info = lambda *args, **kwargs: []
sys.modules["app.services.tjnetwork"] = tjnetwork_module
module_name = "tests_burst_location_under_test"
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module)
return module
def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, tmp_path):
module = _load_burst_location_module()
captured = {}
scheme_calls = []
monkeypatch.setattr(
module,
"get_all_scada_info",
lambda network: [
{
"type": "pressure",
"associated_element_id": "J1",
"api_query_id": "pressure-query",
},
{
"type": "pipe_flow",
"associated_element_id": "P1",
"api_query_id": "pipe-flow-query",
},
{
"type": "demand",
"associated_element_id": "J2",
"api_query_id": "demand-query",
},
],
)
monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
def fake_run_burst_location(**kwargs):
captured.update(kwargs)
return {
"located_pipe": "Pipe-001",
"simulation_times": 3,
"similarity_mode": "combined",
}
monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location)
def _build_series(start_time: str, values: list[float]) -> list[dict]:
base_time = datetime.fromisoformat(start_time)
return [
{
"time": (base_time + timedelta(minutes=15 * index)).isoformat(),
"value": value,
}
for index, value in enumerate(values)
]
def fake_scheme_query(**kwargs):
scheme_calls.append(kwargs)
if kwargs["element_type"] == "node" and kwargs["field"] == "pressure":
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
values = [12.0, 14.0, 16.0, 18.0] if start_hour == 8 else [8.0, 10.0, 12.0, 14.0]
return {"J1": _build_series(kwargs["start_time"], values)}
if kwargs["element_type"] == "link" and kwargs["field"] == "flow":
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
values = [5.0, 7.0, 9.0, 11.0] if start_hour == 8 else [2.0, 4.0, 6.0, 8.0]
return {"P1": _build_series(kwargs["start_time"], values)}
if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand":
start_hour = datetime.fromisoformat(kwargs["start_time"]).hour
values = [3.0, 5.0, 7.0, 9.0] if start_hour == 8 else [1.0, 3.0, 5.0, 7.0]
return {"J2": _build_series(kwargs["start_time"], values)}
raise AssertionError(f"Unexpected scheme query: {kwargs}")
monkeypatch.setattr(
module.InternalQueries,
"query_scheme_simulation_by_ids_timerange",
staticmethod(fake_scheme_query),
)
monkeypatch.setattr(
module.InternalQueries,
"query_realtime_simulation_by_ids_timerange",
staticmethod(lambda **kwargs: pytest.fail(f"Unexpected realtime query: {kwargs}")),
)
result = module.run_burst_location_by_network(
network="tjwater",
data_source="simulation",
simulation_scheme_name="BurstSchemeA",
simulation_scheme_type="burst_analysis",
burst_leakage=10.0,
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
use_scada_flow=True,
)
assert result["observed_source"] == "simulation_scheme_timerange"
assert result["simulation_scheme"] == {
"name": "BurstSchemeA",
"type": "burst_analysis",
}
assert result["pressure_samples"] == {"burst": 4, "normal": 4}
assert result["flow_samples"] == {"burst": 4, "normal": 4}
assert list(captured["burst_pressure"].index) == ["J1"]
assert captured["burst_pressure"]["J1"] == pytest.approx(15.0)
assert captured["normal_pressure"]["J1"] == pytest.approx(11.0)
assert captured["burst_flow"]["J2"] == pytest.approx(6.0)
assert captured["burst_flow"]["P1"] == pytest.approx(8.0)
assert captured["normal_flow"]["J2"] == pytest.approx(4.0)
assert captured["normal_flow"]["P1"] == pytest.approx(5.0)
assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls)
assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls)
assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls)
assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls)
def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path):
module = _load_burst_location_module()
monkeypatch.setattr(
module,
"get_all_scada_info",
lambda network: [
{
"type": "pressure",
"associated_element_id": "J1",
"api_query_id": "pressure-query",
}
],
)
monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp"))
monkeypatch.setattr(module, "run_burst_location", lambda **kwargs: {})
with pytest.raises(ValueError, match="simulation_scheme_name"):
module.run_burst_location_by_network(
network="tjwater",
data_source="simulation",
burst_leakage=1.0,
scada_burst_start=datetime(2025, 1, 1, 8, 0, 0),
scada_burst_end=datetime(2025, 1, 1, 9, 0, 0),
scada_normal_start=datetime(2025, 1, 1, 6, 0, 0),
scada_normal_end=datetime(2025, 1, 1, 7, 0, 0),
)