From 7f481ca2617a6dbd28571a9a9001c3a0c02495b3 Mon Sep 17 00:00:00 2001 From: Jiang Date: Sat, 7 Mar 2026 10:50:25 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=A8=A1=E6=8B=9F=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=BA=90=E6=94=AF=E6=8C=81=EF=BC=8C=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E7=88=86=E7=AE=A1=E5=AE=9A=E4=BD=8D=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/v1/endpoints/burst_location.py | 5 + app/infra/db/timescaledb/internal_queries.py | 158 +++++++++ app/services/burst_location.py | 352 ++++++++++++++++--- tests/unit/test_burst_location_service.py | 210 +++++++++++ 4 files changed, 682 insertions(+), 43 deletions(-) create mode 100644 tests/unit/test_burst_location_service.py diff --git a/app/api/v1/endpoints/burst_location.py b/app/api/v1/endpoints/burst_location.py index 9173690..b3fd6a5 100644 --- a/app/api/v1/endpoints/burst_location.py +++ b/app/api/v1/endpoints/burst_location.py @@ -1,6 +1,8 @@ from typing import Any from datetime import datetime +from typing import Literal + from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -16,6 +18,7 @@ router = APIRouter() class BurstLocationRequest(BaseModel): network: str + data_source: Literal["monitoring", "simulation"] = "monitoring" pressure_scada_ids: list[str] | None = None burst_pressure: dict[str, float] | list[dict[str, Any]] | None = None normal_pressure: dict[str, float] | list[dict[str, Any]] | None = None @@ -31,6 +34,8 @@ class BurstLocationRequest(BaseModel): scada_normal_end: datetime | None = None use_scada_flow: bool = False scheme_name: str | None = None + simulation_scheme_name: str | None = None + simulation_scheme_type: str | None = None @router.post("/locate/") diff --git a/app/infra/db/timescaledb/internal_queries.py b/app/infra/db/timescaledb/internal_queries.py index 29d735e..ee12d1e 100644 --- a/app/infra/db/timescaledb/internal_queries.py +++ b/app/infra/db/timescaledb/internal_queries.py @@ -3,6 +3,8 @@ from typing import List from fastapi.logger import logger from datetime import datetime, timedelta import psycopg +from psycopg import sql +from psycopg.rows import dict_row import time from app.infra.db.timescaledb.schemas.scheme import SchemeRepository from app.infra.db.timescaledb.schemas.realtime import RealtimeRepository @@ -170,3 +172,159 @@ class InternalQueries: time.sleep(1) else: raise + + @staticmethod + def query_realtime_simulation_by_ids_timerange( + element_ids: List[str], + start_time: str | datetime, + end_time: str | datetime, + element_type: str, + field: str, + db_name: str = None, + max_retries: int = 3, + ) -> dict[str, list[dict]]: + """查询实时模拟结果,返回 {id: [{time, value}, ...]}。""" + return InternalQueries._query_simulation_by_ids_timerange( + schema_name="realtime", + element_ids=element_ids, + start_time=start_time, + end_time=end_time, + element_type=element_type, + field=field, + db_name=db_name, + max_retries=max_retries, + ) + + @staticmethod + def query_scheme_simulation_by_ids_timerange( + element_ids: List[str], + start_time: str | datetime, + end_time: str | datetime, + element_type: str, + field: str, + scheme_type: str, + scheme_name: str, + db_name: str = None, + max_retries: int = 3, + ) -> dict[str, list[dict]]: + """查询方案模拟结果,返回 {id: [{time, value}, ...]}。""" + return InternalQueries._query_simulation_by_ids_timerange( + schema_name="scheme", + element_ids=element_ids, + start_time=start_time, + end_time=end_time, + element_type=element_type, + field=field, + db_name=db_name, + max_retries=max_retries, + scheme_type=scheme_type, + scheme_name=scheme_name, + ) + + @staticmethod + def _query_simulation_by_ids_timerange( + *, + schema_name: str, + element_ids: List[str], + start_time: str | datetime, + end_time: str | datetime, + element_type: str, + field: str, + db_name: str = None, + max_retries: int = 3, + scheme_type: str | None = None, + scheme_name: str | None = None, + ) -> dict[str, list[dict]]: + if not element_ids: + return {} + + start_dt = ( + datetime.fromisoformat(start_time) + if isinstance(start_time, str) + else start_time + ) + end_dt = ( + datetime.fromisoformat(end_time) if isinstance(end_time, str) else end_time + ) + table_name, valid_fields = InternalQueries._resolve_simulation_table(element_type) + if field not in valid_fields: + raise ValueError(f"Invalid field for {element_type}: {field}") + if schema_name not in {"realtime", "scheme"}: + raise ValueError(f"Unsupported schema_name: {schema_name}") + if schema_name == "scheme" and (not scheme_type or not scheme_name): + raise ValueError("scheme 查询必须提供 scheme_type 和 scheme_name。") + + for attempt in range(max_retries): + try: + conn_string = ( + timescaledb_info.get_pgconn_string(db_name=db_name) + if db_name + else timescaledb_info.get_pgconn_string() + ) + with psycopg.Connection.connect(conn_string) as conn: + with conn.cursor(row_factory=dict_row) as cur: + if schema_name == "scheme": + query = sql.SQL( + "SELECT id, time, {} FROM {}.{} " + "WHERE scheme_type = %s AND scheme_name = %s " + "AND time >= %s AND time <= %s AND id = ANY(%s)" + ).format( + sql.Identifier(field), + sql.Identifier(schema_name), + sql.Identifier(table_name), + ) + cur.execute( + query, + ( + scheme_type, + scheme_name, + start_dt, + end_dt, + element_ids, + ), + ) + else: + query = sql.SQL( + "SELECT id, time, {} FROM {}.{} " + "WHERE time >= %s AND time <= %s AND id = ANY(%s)" + ).format( + sql.Identifier(field), + sql.Identifier(schema_name), + sql.Identifier(table_name), + ) + cur.execute(query, (start_dt, end_dt, element_ids)) + rows = cur.fetchall() + result: dict[str, list[dict]] = { + element_id: [] for element_id in element_ids + } + for row in rows: + result.setdefault(row["id"], []).append( + {"time": row["time"].isoformat(), "value": row[field]} + ) + for element_id in result: + result[element_id].sort(key=lambda item: item["time"]) + return result + except Exception as e: + logger.error(f"查询尝试 {attempt + 1} 失败: {e}") + if attempt < max_retries - 1: + time.sleep(1) + else: + raise + + @staticmethod + def _resolve_simulation_table(element_type: str) -> tuple[str, set[str]]: + normalized_type = element_type.lower() + if normalized_type == "node": + return "node_simulation", {"actual_demand", "total_head", "pressure", "quality"} + if normalized_type == "link": + return "link_simulation", { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } + raise ValueError(f"Unsupported element_type: {element_type}") diff --git a/app/services/burst_location.py b/app/services/burst_location.py index 77ee833..eeb1f1f 100644 --- a/app/services/burst_location.py +++ b/app/services/burst_location.py @@ -18,6 +18,8 @@ from app.services.tjnetwork import dump_inp, get_all_scada_info SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]] FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"} +SIMULATION_DATA_SOURCES = {"monitoring", "simulation"} +DEFAULT_SIMULATION_SCHEME_TYPE = "burst_analysis" def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series: @@ -44,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series: def run_burst_location_by_network( *, network: str, + data_source: str = "monitoring", burst_leakage: float, pressure_scada_ids: list[str] | None = None, burst_pressure: SeriesInput | None = None, @@ -59,10 +62,18 @@ def run_burst_location_by_network( scada_normal_end: datetime | str | None = None, use_scada_flow: bool = False, scheme_name: str | None = None, + simulation_scheme_name: str | None = None, + simulation_scheme_type: str | None = None, username: str = "admin", ) -> dict[str, Any]: if not network: raise ValueError("network is required.") + normalized_data_source = _normalize_data_source( + data_source, simulation_scheme_name=simulation_scheme_name + ) + resolved_simulation_scheme_type = ( + simulation_scheme_type or DEFAULT_SIMULATION_SCHEME_TYPE + ) selected_pressure_ids = ( _dedupe_ids(pressure_scada_ids) @@ -93,26 +104,60 @@ def run_burst_location_by_network( scada_normal_start=scada_normal_start, scada_normal_end=scada_normal_end, ) - burst_pressure_series, burst_pressure_samples = _build_observed_series_from_scada( - network=network, - sensor_ids=selected_pressure_ids, - start_dt=burst_start_dt, - end_dt=burst_end_dt, - data_type="pressure", - series_name="burst_pressure", - ) - ( - normal_pressure_series, - normal_pressure_samples, - ) = _build_observed_series_from_scada( - network=network, - sensor_ids=selected_pressure_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, - data_type="pressure", - series_name="normal_pressure", - ) - observed_source = "backend_timerange" + if normalized_data_source == "simulation": + if not simulation_scheme_name: + raise ValueError("模拟方案模式必须提供 simulation_scheme_name。") + burst_pressure_series, burst_pressure_samples = ( + _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_pressure_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="pressure", + series_name="burst_pressure", + simulation_source="scheme", + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=resolved_simulation_scheme_type, + ) + ) + ( + normal_pressure_series, + normal_pressure_samples, + ) = _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_pressure_ids, + start_dt=normal_start_dt, + end_dt=normal_end_dt, + data_type="pressure", + series_name="normal_pressure", + simulation_source="scheme", + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=resolved_simulation_scheme_type, + ) + observed_source = "simulation_scheme_timerange" + else: + burst_pressure_series, burst_pressure_samples = ( + _build_observed_series_from_scada( + network=network, + sensor_ids=selected_pressure_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="pressure", + series_name="burst_pressure", + ) + ) + ( + normal_pressure_series, + normal_pressure_samples, + ) = _build_observed_series_from_scada( + network=network, + sensor_ids=selected_pressure_ids, + start_dt=normal_start_dt, + end_dt=normal_end_dt, + data_type="pressure", + series_name="normal_pressure", + ) + observed_source = "backend_timerange" else: if burst_pressure is None or normal_pressure is None: raise ValueError( @@ -139,22 +184,52 @@ def run_burst_location_by_network( ) if not selected_flow_ids: raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。") - burst_flow_series, burst_flow_samples = _build_observed_series_from_scada( - network=network, - sensor_ids=selected_flow_ids, - start_dt=burst_start_dt, - end_dt=burst_end_dt, - data_type="flow", - series_name="burst_flow", - ) - normal_flow_series, normal_flow_samples = _build_observed_series_from_scada( - network=network, - sensor_ids=selected_flow_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, - data_type="flow", - series_name="normal_flow", - ) + if normalized_data_source == "simulation": + if not simulation_scheme_name: + raise ValueError("模拟方案模式必须提供 simulation_scheme_name。") + burst_flow_series, burst_flow_samples = ( + _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_flow_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="flow", + series_name="burst_flow", + simulation_source="scheme", + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=resolved_simulation_scheme_type, + ) + ) + normal_flow_series, normal_flow_samples = ( + _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_flow_ids, + start_dt=normal_start_dt, + end_dt=normal_end_dt, + data_type="flow", + series_name="normal_flow", + simulation_source="scheme", + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=resolved_simulation_scheme_type, + ) + ) + else: + burst_flow_series, burst_flow_samples = _build_observed_series_from_scada( + network=network, + sensor_ids=selected_flow_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="flow", + series_name="burst_flow", + ) + normal_flow_series, normal_flow_samples = _build_observed_series_from_scada( + network=network, + sensor_ids=selected_flow_ids, + start_dt=normal_start_dt, + end_dt=normal_end_dt, + data_type="flow", + series_name="normal_flow", + ) else: if flow_scada_ids is not None: selected_flow_ids = _dedupe_ids(flow_scada_ids) @@ -186,6 +261,7 @@ def run_burst_location_by_network( payload: dict[str, Any] = { **result, "network": network, + "data_source": normalized_data_source, "pressure_scada_ids": selected_pressure_ids, "flow_scada_ids": selected_flow_ids or [], "observed_source": observed_source, @@ -202,6 +278,11 @@ def run_burst_location_by_network( "normal_start": normal_start_dt.isoformat(), "normal_end": normal_end_dt.isoformat(), } + if normalized_data_source == "simulation": + payload["simulation_scheme"] = { + "name": simulation_scheme_name, + "type": resolved_simulation_scheme_type, + } if scheme_name: _store_burst_scheme( network=network, @@ -335,8 +416,174 @@ def _build_observed_series_from_scada( return pd.Series(values, dtype=float), min(sample_counts) -def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: - mapping: dict[str, str] = {} +def _build_observed_series_from_simulation( + *, + network: str, + sensor_ids: list[str], + start_dt: datetime, + end_dt: datetime, + data_type: str, + series_name: str, + simulation_source: str, + simulation_scheme_name: str | None, + simulation_scheme_type: str, +) -> tuple[pd.Series, int]: + sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type) + missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata] + if missing_ids: + preview = ", ".join(missing_ids[:10]) + raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}") + + simulation_data = _query_simulation_data_by_sensor_ids( + network=network, + sensor_ids=sensor_ids, + sensor_metadata=sensor_metadata, + start_dt=start_dt, + end_dt=end_dt, + data_type=data_type, + simulation_source=simulation_source, + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=simulation_scheme_type, + ) + values: dict[str, float] = {} + sample_counts: list[int] = [] + for sensor_id in sensor_ids: + records = simulation_data.get(sensor_id, []) + numeric_values = [ + float(item["value"]) + for item in records + if item.get("value") is not None + ] + if not numeric_values: + raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}") + values[sensor_id] = float(sum(numeric_values) / len(numeric_values)) + sample_counts.append(len(numeric_values)) + + return pd.Series(values, dtype=float), min(sample_counts) + + +def _query_simulation_data_by_sensor_ids( + *, + network: str, + sensor_ids: list[str], + sensor_metadata: dict[str, dict[str, str]], + start_dt: datetime, + end_dt: datetime, + data_type: str, + simulation_source: str, + simulation_scheme_name: str | None, + simulation_scheme_type: str, +) -> dict[str, list[dict[str, Any]]]: + if simulation_source not in {"scheme", "realtime"}: + raise ValueError(f"Unsupported simulation_source: {simulation_source}") + + result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids} + if data_type == "pressure": + result.update( + _query_simulation_values( + network=network, + element_ids=sensor_ids, + element_type="node", + field="pressure", + start_dt=start_dt, + end_dt=end_dt, + simulation_source=simulation_source, + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=simulation_scheme_type, + ) + ) + return result + + if data_type != "flow": + raise ValueError(f"Unsupported data_type: {data_type}") + + link_ids: list[str] = [] + demand_ids: list[str] = [] + unsupported_ids: list[str] = [] + for sensor_id in sensor_ids: + scada_type = sensor_metadata[sensor_id]["scada_type"] + if scada_type in {"pipe_flow", "flow"}: + link_ids.append(sensor_id) + elif scada_type == "demand": + demand_ids.append(sensor_id) + else: + unsupported_ids.append(f"{sensor_id}({scada_type})") + if unsupported_ids: + preview = ", ".join(unsupported_ids[:10]) + raise ValueError(f"flow 模拟数据暂不支持以下 SCADA 类型: {preview}") + + if link_ids: + result.update( + _query_simulation_values( + network=network, + element_ids=link_ids, + element_type="link", + field="flow", + start_dt=start_dt, + end_dt=end_dt, + simulation_source=simulation_source, + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=simulation_scheme_type, + ) + ) + if demand_ids: + result.update( + _query_simulation_values( + network=network, + element_ids=demand_ids, + element_type="node", + field="actual_demand", + start_dt=start_dt, + end_dt=end_dt, + simulation_source=simulation_source, + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=simulation_scheme_type, + ) + ) + return result + + +def _query_simulation_values( + *, + network: str, + element_ids: list[str], + element_type: str, + field: str, + start_dt: datetime, + end_dt: datetime, + simulation_source: str, + simulation_scheme_name: str | None, + simulation_scheme_type: str, +) -> dict[str, list[dict[str, Any]]]: + if not element_ids: + return {} + if simulation_source == "scheme": + if not simulation_scheme_name: + raise ValueError("读取方案模拟数据时必须提供 simulation_scheme_name。") + return InternalQueries.query_scheme_simulation_by_ids_timerange( + db_name=network, + scheme_type=simulation_scheme_type, + scheme_name=simulation_scheme_name, + element_ids=element_ids, + start_time=start_dt.isoformat(), + end_time=end_dt.isoformat(), + element_type=element_type, + field=field, + ) + if simulation_source == "realtime": + return InternalQueries.query_realtime_simulation_by_ids_timerange( + db_name=network, + element_ids=element_ids, + start_time=start_dt.isoformat(), + end_time=end_dt.isoformat(), + element_type=element_type, + field=field, + ) + raise ValueError(f"Unsupported simulation_source: {simulation_source}") + + +def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str, str]]: + metadata: dict[str, dict[str, str]] = {} for item in get_all_scada_info(network): scada_type = str(item.get("type", "")).lower() if data_type == "pressure": @@ -347,16 +594,35 @@ def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: continue else: raise ValueError(f"Unsupported data_type: {data_type}") - node_id = item.get("associated_element_id") + element_id = item.get("associated_element_id") query_id = item.get("api_query_id") if ( - isinstance(node_id, str) - and node_id + isinstance(element_id, str) + and element_id and isinstance(query_id, str) and query_id ): - mapping[node_id] = query_id - return mapping + metadata[element_id] = {"query_id": query_id, "scada_type": scada_type} + return metadata + + +def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: + metadata = _build_sensor_metadata(network=network, data_type=data_type) + return { + element_id: item["query_id"] for element_id, item in metadata.items() + } + + +def _normalize_data_source( + data_source: str | None, simulation_scheme_name: str | None = None +) -> str: + normalized = str(data_source or "").strip().lower() + if not normalized: + return "simulation" if simulation_scheme_name else "monitoring" + if normalized not in SIMULATION_DATA_SOURCES: + allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES)) + raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}") + return normalized def _get_sensor_nodes(network: str, data_type: str) -> list[str]: diff --git a/tests/unit/test_burst_location_service.py b/tests/unit/test_burst_location_service.py new file mode 100644 index 0000000..1511634 --- /dev/null +++ b/tests/unit/test_burst_location_service.py @@ -0,0 +1,210 @@ +import importlib.util +import sys +import types +from datetime import datetime, timedelta +from pathlib import Path + +import pytest + + +def _load_burst_location_module(): + module_path = ( + Path(__file__).resolve().parents[2] / "app" / "services" / "burst_location.py" + ) + + def ensure_package(name: str) -> types.ModuleType: + module = sys.modules.get(name) + if module is None: + module = types.ModuleType(name) + module.__path__ = [] + sys.modules[name] = module + return module + + for package_name in [ + "app", + "app.algorithms", + "app.infra", + "app.infra.db", + "app.infra.db.timescaledb", + "app.services", + ]: + ensure_package(package_name) + + algorithms_module = types.ModuleType("app.algorithms.burst_location") + algorithms_module.run_burst_location = lambda **kwargs: {} + sys.modules["app.algorithms.burst_location"] = algorithms_module + + internal_queries_module = types.ModuleType( + "app.infra.db.timescaledb.internal_queries" + ) + + class DummyInternalQueries: + @staticmethod + def query_scada_by_ids_timerange(**kwargs): + return {} + + @staticmethod + def query_scheme_simulation_by_ids_timerange(**kwargs): + return {} + + @staticmethod + def query_realtime_simulation_by_ids_timerange(**kwargs): + return {} + + internal_queries_module.InternalQueries = DummyInternalQueries + sys.modules["app.infra.db.timescaledb.internal_queries"] = internal_queries_module + + scheme_management_module = types.ModuleType("app.services.scheme_management") + scheme_management_module.query_burst_location_scheme_detail = lambda *args, **kwargs: {} + scheme_management_module.query_burst_location_schemes = lambda *args, **kwargs: [] + scheme_management_module.scheme_name_exists = lambda *args, **kwargs: False + scheme_management_module.store_scheme_info = lambda *args, **kwargs: None + sys.modules["app.services.scheme_management"] = scheme_management_module + + tjnetwork_module = types.ModuleType("app.services.tjnetwork") + tjnetwork_module.dump_inp = lambda *args, **kwargs: None + tjnetwork_module.get_all_scada_info = lambda *args, **kwargs: [] + sys.modules["app.services.tjnetwork"] = tjnetwork_module + + module_name = "tests_burst_location_under_test" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) + return module + + +def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, tmp_path): + module = _load_burst_location_module() + captured = {} + scheme_calls = [] + + monkeypatch.setattr( + module, + "get_all_scada_info", + lambda network: [ + { + "type": "pressure", + "associated_element_id": "J1", + "api_query_id": "pressure-query", + }, + { + "type": "pipe_flow", + "associated_element_id": "P1", + "api_query_id": "pipe-flow-query", + }, + { + "type": "demand", + "associated_element_id": "J2", + "api_query_id": "demand-query", + }, + ], + ) + monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp")) + + def fake_run_burst_location(**kwargs): + captured.update(kwargs) + return { + "located_pipe": "Pipe-001", + "simulation_times": 3, + "similarity_mode": "combined", + } + + monkeypatch.setattr(module, "run_burst_location", fake_run_burst_location) + + def _build_series(start_time: str, values: list[float]) -> list[dict]: + base_time = datetime.fromisoformat(start_time) + return [ + { + "time": (base_time + timedelta(minutes=15 * index)).isoformat(), + "value": value, + } + for index, value in enumerate(values) + ] + + def fake_scheme_query(**kwargs): + scheme_calls.append(kwargs) + if kwargs["element_type"] == "node" and kwargs["field"] == "pressure": + start_hour = datetime.fromisoformat(kwargs["start_time"]).hour + values = [12.0, 14.0, 16.0, 18.0] if start_hour == 8 else [8.0, 10.0, 12.0, 14.0] + return {"J1": _build_series(kwargs["start_time"], values)} + if kwargs["element_type"] == "link" and kwargs["field"] == "flow": + start_hour = datetime.fromisoformat(kwargs["start_time"]).hour + values = [5.0, 7.0, 9.0, 11.0] if start_hour == 8 else [2.0, 4.0, 6.0, 8.0] + return {"P1": _build_series(kwargs["start_time"], values)} + if kwargs["element_type"] == "node" and kwargs["field"] == "actual_demand": + start_hour = datetime.fromisoformat(kwargs["start_time"]).hour + values = [3.0, 5.0, 7.0, 9.0] if start_hour == 8 else [1.0, 3.0, 5.0, 7.0] + return {"J2": _build_series(kwargs["start_time"], values)} + raise AssertionError(f"Unexpected scheme query: {kwargs}") + + monkeypatch.setattr( + module.InternalQueries, + "query_scheme_simulation_by_ids_timerange", + staticmethod(fake_scheme_query), + ) + monkeypatch.setattr( + module.InternalQueries, + "query_realtime_simulation_by_ids_timerange", + staticmethod(lambda **kwargs: pytest.fail(f"Unexpected realtime query: {kwargs}")), + ) + + result = module.run_burst_location_by_network( + network="tjwater", + data_source="simulation", + simulation_scheme_name="BurstSchemeA", + simulation_scheme_type="burst_analysis", + burst_leakage=10.0, + scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), + scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), + scada_normal_start=datetime(2025, 1, 1, 6, 0, 0), + scada_normal_end=datetime(2025, 1, 1, 7, 0, 0), + use_scada_flow=True, + ) + + assert result["observed_source"] == "simulation_scheme_timerange" + assert result["simulation_scheme"] == { + "name": "BurstSchemeA", + "type": "burst_analysis", + } + assert result["pressure_samples"] == {"burst": 4, "normal": 4} + assert result["flow_samples"] == {"burst": 4, "normal": 4} + assert list(captured["burst_pressure"].index) == ["J1"] + assert captured["burst_pressure"]["J1"] == pytest.approx(15.0) + assert captured["normal_pressure"]["J1"] == pytest.approx(11.0) + assert captured["burst_flow"]["J2"] == pytest.approx(6.0) + assert captured["burst_flow"]["P1"] == pytest.approx(8.0) + assert captured["normal_flow"]["J2"] == pytest.approx(4.0) + assert captured["normal_flow"]["P1"] == pytest.approx(5.0) + assert all(call["scheme_name"] == "BurstSchemeA" for call in scheme_calls) + assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in scheme_calls) + assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls) + assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls) + + +def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path): + module = _load_burst_location_module() + monkeypatch.setattr( + module, + "get_all_scada_info", + lambda network: [ + { + "type": "pressure", + "associated_element_id": "J1", + "api_query_id": "pressure-query", + } + ], + ) + monkeypatch.setattr(module, "_prepare_burst_inp", lambda network: str(tmp_path / "fake.inp")) + monkeypatch.setattr(module, "run_burst_location", lambda **kwargs: {}) + + with pytest.raises(ValueError, match="simulation_scheme_name"): + module.run_burst_location_by_network( + network="tjwater", + data_source="simulation", + burst_leakage=1.0, + scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), + scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), + scada_normal_start=datetime(2025, 1, 1, 6, 0, 0), + scada_normal_end=datetime(2025, 1, 1, 7, 0, 0), + )