from __future__ import annotations import os from datetime import datetime from typing import Any import pandas as pd from app.algorithms.burst_location import run_burst_location from app.infra.db.timescaledb.internal_queries import InternalQueries from app.services.scheme_management import ( query_burst_location_scheme_detail, query_burst_location_schemes, scheme_name_exists, store_scheme_info, ) from app.services.tjnetwork import dump_inp, get_all_scada_info SeriesInput = pd.Series | dict[str, Any] | list[dict[str, Any]] FLOW_SCADA_TYPES = {"pipe_flow", "flow", "demand"} def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series: if isinstance(data, pd.Series): series = data.copy() elif isinstance(data, dict): series = pd.Series(data, dtype=float) elif isinstance(data, list): if len(data) == 0: return pd.Series(dtype=float) frame = pd.DataFrame(data) if not {"id", "value"}.issubset(frame.columns): raise ValueError(f"{field_name} list item must include 'id' and 'value'.") series = pd.Series( frame["value"].values, index=frame["id"].astype(str).values, dtype=float ) else: raise ValueError(f"Unsupported data format for {field_name}.") series.index = series.index.map(str) return pd.to_numeric(series, errors="raise") def run_burst_location_by_network( *, network: str, burst_leakage: float, pressure_scada_ids: list[str] | None = None, burst_pressure: SeriesInput | None = None, normal_pressure: SeriesInput | None = None, flow_scada_ids: list[str] | None = None, burst_flow: SeriesInput | None = None, normal_flow: SeriesInput | None = None, min_dpressure: float = 2.0, basic_pressure: float = 10.0, scada_burst_start: datetime | str | None = None, scada_burst_end: datetime | str | None = None, scada_normal_start: datetime | str | None = None, scada_normal_end: datetime | str | None = None, use_scada_flow: bool = False, scheme_name: str | None = None, username: str = "admin", ) -> dict[str, Any]: if not network: raise ValueError("network is required.") selected_pressure_ids = ( _dedupe_ids(pressure_scada_ids) if pressure_scada_ids else _get_sensor_nodes(network, data_type="pressure") ) if not selected_pressure_ids: raise ValueError("未提供有效压力传感器,且系统未识别到可用压力传感器。") use_scada_pressure = any( value is not None for value in [ scada_burst_start, scada_burst_end, scada_normal_start, scada_normal_end, ] ) if use_scada_pressure: ( burst_start_dt, burst_end_dt, normal_start_dt, normal_end_dt, ) = _validate_scada_windows( scada_burst_start=scada_burst_start, scada_burst_end=scada_burst_end, scada_normal_start=scada_normal_start, scada_normal_end=scada_normal_end, ) burst_pressure_series, burst_pressure_samples = _build_observed_series_from_scada( network=network, sensor_ids=selected_pressure_ids, start_dt=burst_start_dt, end_dt=burst_end_dt, data_type="pressure", series_name="burst_pressure", ) ( normal_pressure_series, normal_pressure_samples, ) = _build_observed_series_from_scada( network=network, sensor_ids=selected_pressure_ids, start_dt=normal_start_dt, end_dt=normal_end_dt, data_type="pressure", series_name="normal_pressure", ) observed_source = "backend_timerange" else: if burst_pressure is None or normal_pressure is None: raise ValueError( "未提供 burst_pressure/normal_pressure,且未提供完整 SCADA 时间窗参数。" ) burst_pressure_series = _normalize_series(burst_pressure, "burst_pressure") normal_pressure_series = _normalize_series(normal_pressure, "normal_pressure") burst_pressure_samples = 1 normal_pressure_samples = 1 observed_source = "request_payload" burst_start_dt = burst_end_dt = normal_start_dt = normal_end_dt = None selected_flow_ids: list[str] | None = None burst_flow_series: pd.Series | None = None normal_flow_series: pd.Series | None = None use_flow_scada_source = use_scada_pressure and ( use_scada_flow or flow_scada_ids is not None ) if use_flow_scada_source: selected_flow_ids = ( _dedupe_ids(flow_scada_ids) if flow_scada_ids is not None else _get_sensor_nodes(network, data_type="flow") ) if not selected_flow_ids: raise ValueError("未找到可用流量传感器,无法从 SCADA 查询流量数据。") burst_flow_series, burst_flow_samples = _build_observed_series_from_scada( network=network, sensor_ids=selected_flow_ids, start_dt=burst_start_dt, end_dt=burst_end_dt, data_type="flow", series_name="burst_flow", ) normal_flow_series, normal_flow_samples = _build_observed_series_from_scada( network=network, sensor_ids=selected_flow_ids, start_dt=normal_start_dt, end_dt=normal_end_dt, data_type="flow", series_name="normal_flow", ) else: if flow_scada_ids is not None: selected_flow_ids = _dedupe_ids(flow_scada_ids) burst_flow_series = ( _normalize_series(burst_flow, "burst_flow") if burst_flow is not None else None ) normal_flow_series = ( _normalize_series(normal_flow, "normal_flow") if normal_flow is not None else None ) burst_flow_samples = 1 if burst_flow_series is not None else 0 normal_flow_samples = 1 if normal_flow_series is not None else 0 inp_path = _prepare_burst_inp(network) result = run_burst_location( wn_inp_path=inp_path, pressure_scada_ids=selected_pressure_ids, burst_pressure=burst_pressure_series, normal_pressure=normal_pressure_series, burst_leakage=burst_leakage, flow_scada_ids=selected_flow_ids, burst_flow=burst_flow_series, normal_flow=normal_flow_series, min_dpressure=min_dpressure, basic_pressure=basic_pressure, ) payload: dict[str, Any] = { **result, "network": network, "pressure_scada_ids": selected_pressure_ids, "flow_scada_ids": selected_flow_ids or [], "observed_source": observed_source, "pressure_samples": { "burst": burst_pressure_samples, "normal": normal_pressure_samples, }, "flow_samples": {"burst": burst_flow_samples, "normal": normal_flow_samples}, } if use_scada_pressure: payload["scada_window"] = { "burst_start": burst_start_dt.isoformat(), "burst_end": burst_end_dt.isoformat(), "normal_start": normal_start_dt.isoformat(), "normal_end": normal_end_dt.isoformat(), } if scheme_name: _store_burst_scheme( network=network, scheme_name=scheme_name, username=username, payload=payload, burst_leakage=burst_leakage, min_dpressure=min_dpressure, basic_pressure=basic_pressure, ) return payload def list_burst_location_schemes( network: str, query_date: datetime | str | None = None ) -> list[dict[str, Any]]: parsed_date = _to_datetime(query_date).date() if query_date is not None else None return query_burst_location_schemes(name=network, network=network, query_date=parsed_date) def get_burst_location_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]: result = query_burst_location_scheme_detail(network, scheme_name) if not result: raise ValueError(f"未找到爆管定位方案: {scheme_name}") return result def _store_burst_scheme( *, network: str, scheme_name: str, username: str, payload: dict[str, Any], burst_leakage: float, min_dpressure: float, basic_pressure: float, ) -> None: if scheme_name_exists(network, scheme_name): raise ValueError(f"方案名称已存在: {scheme_name}") now_iso = datetime.now().isoformat() scheme_detail = { "network": network, "pressure_scada_ids": payload.get("pressure_scada_ids", []), "flow_scada_ids": payload.get("flow_scada_ids", []), "observed_source": payload.get("observed_source"), "algorithm_params": { "burst_leakage": burst_leakage, "min_dpressure": min_dpressure, "basic_pressure": basic_pressure, }, "scada_window": payload.get("scada_window"), "result_summary": { "located_pipe": payload.get("located_pipe"), "simulation_times": payload.get("simulation_times"), "similarity_mode": payload.get("similarity_mode"), }, "result_payload": payload, } store_scheme_info( name=network, scheme_name=scheme_name, scheme_type="burst_location", username=username, scheme_start_time=now_iso, scheme_detail=scheme_detail, ) def _validate_scada_windows( *, scada_burst_start: datetime | str | None, scada_burst_end: datetime | str | None, scada_normal_start: datetime | str | None, scada_normal_end: datetime | str | None, ) -> tuple[datetime, datetime, datetime, datetime]: values = [scada_burst_start, scada_burst_end, scada_normal_start, scada_normal_end] if any(v is None for v in values): raise ValueError( "使用后端 SCADA 查询时,必须同时提供 scada_burst_start/scada_burst_end/scada_normal_start/scada_normal_end。" ) burst_start_dt = _to_datetime(scada_burst_start) burst_end_dt = _to_datetime(scada_burst_end) normal_start_dt = _to_datetime(scada_normal_start) normal_end_dt = _to_datetime(scada_normal_end) if burst_start_dt >= burst_end_dt: raise ValueError("爆管时段 SCADA 时间窗非法:scada_burst_start 必须早于 scada_burst_end。") if normal_start_dt >= normal_end_dt: raise ValueError( "正常时段 SCADA 时间窗非法:scada_normal_start 必须早于 scada_normal_end。" ) return burst_start_dt, burst_end_dt, normal_start_dt, normal_end_dt def _build_observed_series_from_scada( *, network: str, sensor_ids: list[str], start_dt: datetime, end_dt: datetime, data_type: str, series_name: str, ) -> tuple[pd.Series, int]: scada_mapping = _build_scada_mapping(network=network, data_type=data_type) missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in scada_mapping] if missing_ids: preview = ", ".join(missing_ids[:10]) raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}") query_ids = [scada_mapping[sensor_id] for sensor_id in sensor_ids] scada_data = InternalQueries.query_scada_by_ids_timerange( db_name=network, device_ids=query_ids, start_time=start_dt.isoformat(), end_time=end_dt.isoformat(), ) values: dict[str, float] = {} sample_counts: list[int] = [] for sensor_id, query_id in zip(sensor_ids, query_ids): records = scada_data.get(query_id, []) numeric_values = [ float(item["value"]) for item in records if item.get("value") is not None ] if not numeric_values: raise ValueError(f"{series_name} 在时间窗内无有效数据: {sensor_id}") values[sensor_id] = float(sum(numeric_values) / len(numeric_values)) sample_counts.append(len(numeric_values)) return pd.Series(values, dtype=float), min(sample_counts) def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: mapping: dict[str, str] = {} for item in get_all_scada_info(network): scada_type = str(item.get("type", "")).lower() if data_type == "pressure": if scada_type != "pressure": continue elif data_type == "flow": if scada_type not in FLOW_SCADA_TYPES: continue else: raise ValueError(f"Unsupported data_type: {data_type}") node_id = item.get("associated_element_id") query_id = item.get("api_query_id") if ( isinstance(node_id, str) and node_id and isinstance(query_id, str) and query_id ): mapping[node_id] = query_id return mapping def _get_sensor_nodes(network: str, data_type: str) -> list[str]: mapping = _build_scada_mapping(network=network, data_type=data_type) sensor_ids = sorted(mapping.keys()) if not sensor_ids: type_name = "压力" if data_type == "pressure" else "流量" raise ValueError(f"未找到{type_name}传感器对应节点(scada_info.type)。") return sensor_ids def _dedupe_ids(ids: list[str] | None) -> list[str]: if ids is None: return [] return list(dict.fromkeys([str(item) for item in ids if item])) def _to_datetime(value: datetime | str) -> datetime: if isinstance(value, datetime): return value return datetime.fromisoformat(value) def _prepare_burst_inp(network: str) -> str: project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) db_inp_dir = os.path.join(project_root, "db_inp") os.makedirs(db_inp_dir, exist_ok=True) inp_path = os.path.join(db_inp_dir, f"{network}.burst.inp") if os.path.isfile(inp_path) and os.path.getsize(inp_path) > 0: return inp_path dump_inp(network, inp_path, "2") if not os.path.isfile(inp_path) or os.path.getsize(inp_path) <= 0: raise ValueError(f"爆管定位 INP 文件无效: {inp_path}") return inp_path