diff --git a/app/api/v1/endpoints/burst_location.py b/app/api/v1/endpoints/burst_location.py index b3fd6a5..e9ec3e4 100644 --- a/app/api/v1/endpoints/burst_location.py +++ b/app/api/v1/endpoints/burst_location.py @@ -30,8 +30,6 @@ class BurstLocationRequest(BaseModel): basic_pressure: float = 10.0 scada_burst_start: datetime | None = None scada_burst_end: datetime | None = None - scada_normal_start: datetime | None = None - scada_normal_end: datetime | None = None use_scada_flow: bool = False scheme_name: str | None = None simulation_scheme_name: str | None = None diff --git a/app/services/burst_location.py b/app/services/burst_location.py index d050ef0..cb8154e 100644 --- a/app/services/burst_location.py +++ b/app/services/burst_location.py @@ -46,6 +46,7 @@ def _normalize_series(data: SeriesInput, field_name: str) -> pd.Series: def run_burst_location_by_network( *, network: str, + username: str, data_source: str = "monitoring", burst_leakage: float, pressure_scada_ids: list[str] | None = None, @@ -58,13 +59,10 @@ def run_burst_location_by_network( basic_pressure: float = 10.0, scada_burst_start: datetime | str | None = None, scada_burst_end: datetime | str | None = None, - scada_normal_start: datetime | str | None = None, - scada_normal_end: datetime | str | None = None, use_scada_flow: bool = False, scheme_name: str | None = None, simulation_scheme_name: str | None = None, simulation_scheme_type: str | None = None, - username: str = "admin", ) -> dict[str, Any]: if not network: raise ValueError("network is required.") @@ -88,37 +86,29 @@ def run_burst_location_by_network( for value in [ scada_burst_start, scada_burst_end, - scada_normal_start, - scada_normal_end, ] ) if use_scada_pressure: - ( - burst_start_dt, - burst_end_dt, - normal_start_dt, - normal_end_dt, - ) = _validate_scada_windows( + burst_start_dt, burst_end_dt = _validate_scada_windows( scada_burst_start=scada_burst_start, scada_burst_end=scada_burst_end, - scada_normal_start=scada_normal_start, - scada_normal_end=scada_normal_end, ) if normalized_data_source == "simulation": if not simulation_scheme_name: raise ValueError("模拟方案模式必须提供 simulation_scheme_name。") - burst_pressure_series, burst_pressure_samples = ( - _build_observed_series_from_simulation( - network=network, - sensor_ids=selected_pressure_ids, - start_dt=burst_start_dt, - end_dt=burst_end_dt, - data_type="pressure", - series_name="burst_pressure", - simulation_source="scheme", - simulation_scheme_name=simulation_scheme_name, - simulation_scheme_type=resolved_simulation_scheme_type, - ) + ( + burst_pressure_series, + burst_pressure_samples, + ) = _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_pressure_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="pressure", + series_name="burst_pressure", + simulation_source="scheme", + simulation_scheme_name=simulation_scheme_name, + simulation_scheme_type=resolved_simulation_scheme_type, ) ( normal_pressure_series, @@ -126,12 +116,12 @@ def run_burst_location_by_network( ) = _build_observed_series_from_simulation( network=network, sensor_ids=selected_pressure_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, + start_dt=burst_start_dt, + end_dt=burst_end_dt, data_type="pressure", series_name="normal_pressure", simulation_source="realtime", - simulation_scheme_name=simulation_scheme_name, + simulation_scheme_name=None, simulation_scheme_type=resolved_simulation_scheme_type, ) observed_source = "simulation_scheme_burst_realtime_normal_timerange" @@ -149,15 +139,18 @@ def run_burst_location_by_network( ( normal_pressure_series, normal_pressure_samples, - ) = _build_observed_series_from_scada( + ) = _build_observed_series_from_simulation( network=network, sensor_ids=selected_pressure_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, + start_dt=burst_start_dt, + end_dt=burst_end_dt, data_type="pressure", series_name="normal_pressure", + simulation_source="realtime", + simulation_scheme_name=None, + simulation_scheme_type=resolved_simulation_scheme_type, ) - observed_source = "backend_timerange" + observed_source = "scada_burst_realtime_normal_timerange" else: if burst_pressure is None or normal_pressure is None: raise ValueError( @@ -168,7 +161,7 @@ def run_burst_location_by_network( burst_pressure_samples = 1 normal_pressure_samples = 1 observed_source = "request_payload" - burst_start_dt = burst_end_dt = normal_start_dt = normal_end_dt = None + burst_start_dt = burst_end_dt = None selected_flow_ids: list[str] | None = None burst_flow_series: pd.Series | None = None @@ -204,12 +197,12 @@ def run_burst_location_by_network( _build_observed_series_from_simulation( network=network, sensor_ids=selected_flow_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, + start_dt=burst_start_dt, + end_dt=burst_end_dt, data_type="flow", series_name="normal_flow", simulation_source="realtime", - simulation_scheme_name=simulation_scheme_name, + simulation_scheme_name=None, simulation_scheme_type=resolved_simulation_scheme_type, ) ) @@ -222,19 +215,26 @@ def run_burst_location_by_network( data_type="flow", series_name="burst_flow", ) - normal_flow_series, normal_flow_samples = _build_observed_series_from_scada( - network=network, - sensor_ids=selected_flow_ids, - start_dt=normal_start_dt, - end_dt=normal_end_dt, - data_type="flow", - series_name="normal_flow", + normal_flow_series, normal_flow_samples = ( + _build_observed_series_from_simulation( + network=network, + sensor_ids=selected_flow_ids, + start_dt=burst_start_dt, + end_dt=burst_end_dt, + data_type="flow", + series_name="normal_flow", + simulation_source="realtime", + simulation_scheme_name=None, + simulation_scheme_type=resolved_simulation_scheme_type, + ) ) else: if flow_scada_ids is not None: selected_flow_ids = _dedupe_ids(flow_scada_ids) burst_flow_series = ( - _normalize_series(burst_flow, "burst_flow") if burst_flow is not None else None + _normalize_series(burst_flow, "burst_flow") + if burst_flow is not None + else None ) normal_flow_series = ( _normalize_series(normal_flow, "normal_flow") @@ -278,8 +278,6 @@ def run_burst_location_by_network( payload["scada_window"] = { "burst_start": burst_start_dt.isoformat(), "burst_end": burst_end_dt.isoformat(), - "normal_start": normal_start_dt.isoformat(), - "normal_end": normal_end_dt.isoformat(), } if normalized_data_source == "simulation": payload["simulation_scheme"] = { @@ -303,7 +301,9 @@ def list_burst_location_schemes( network: str, query_date: datetime | str | None = None ) -> list[dict[str, Any]]: parsed_date = _to_datetime(query_date).date() if query_date is not None else None - return query_burst_location_schemes(name=network, network=network, query_date=parsed_date) + return query_burst_location_schemes( + name=network, network=network, query_date=parsed_date + ) def get_burst_location_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]: @@ -359,25 +359,19 @@ def _validate_scada_windows( *, scada_burst_start: datetime | str | None, scada_burst_end: datetime | str | None, - scada_normal_start: datetime | str | None, - scada_normal_end: datetime | str | None, -) -> tuple[datetime, datetime, datetime, datetime]: - values = [scada_burst_start, scada_burst_end, scada_normal_start, scada_normal_end] +) -> tuple[datetime, datetime]: + values = [scada_burst_start, scada_burst_end] if any(v is None for v in values): raise ValueError( - "使用后端 SCADA 查询时,必须同时提供 scada_burst_start/scada_burst_end/scada_normal_start/scada_normal_end。" + "使用后端 SCADA 查询时,必须同时提供 scada_burst_start/scada_burst_end。" ) burst_start_dt = _to_datetime(scada_burst_start) burst_end_dt = _to_datetime(scada_burst_end) - normal_start_dt = _to_datetime(scada_normal_start) - normal_end_dt = _to_datetime(scada_normal_end) if burst_start_dt >= burst_end_dt: - raise ValueError("爆管时段 SCADA 时间窗非法:scada_burst_start 必须早于 scada_burst_end。") - if normal_start_dt >= normal_end_dt: raise ValueError( - "正常时段 SCADA 时间窗非法:scada_normal_start 必须早于 scada_normal_end。" + "爆管时段 SCADA 时间窗非法:scada_burst_start 必须早于 scada_burst_end。" ) - return burst_start_dt, burst_end_dt, normal_start_dt, normal_end_dt + return burst_start_dt, burst_end_dt def _build_observed_series_from_scada( @@ -390,7 +384,9 @@ def _build_observed_series_from_scada( series_name: str, ) -> tuple[pd.Series, int]: scada_mapping = _build_scada_mapping(network=network, data_type=data_type) - missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in scada_mapping] + missing_ids = [ + sensor_id for sensor_id in sensor_ids if sensor_id not in scada_mapping + ] if missing_ids: preview = ", ".join(missing_ids[:10]) raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}") @@ -407,9 +403,7 @@ def _build_observed_series_from_scada( for sensor_id, query_id in zip(sensor_ids, query_ids): records = scada_data.get(query_id, []) numeric_values = [ - float(item["value"]) - for item in records - if item.get("value") is not None + float(item["value"]) for item in records if item.get("value") is not None ] if not numeric_values: raise ValueError(f"{series_name} 在时间窗内无有效数据: {sensor_id}") @@ -432,7 +426,9 @@ def _build_observed_series_from_simulation( simulation_scheme_type: str, ) -> tuple[pd.Series, int]: sensor_metadata = _build_sensor_metadata(network=network, data_type=data_type) - missing_ids = [sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata] + missing_ids = [ + sensor_id for sensor_id in sensor_ids if sensor_id not in sensor_metadata + ] if missing_ids: preview = ", ".join(missing_ids[:10]) raise ValueError(f"{series_name} 缺少可用 SCADA 映射: {preview}") @@ -453,9 +449,7 @@ def _build_observed_series_from_simulation( for sensor_id in sensor_ids: records = simulation_data.get(sensor_id, []) numeric_values = [ - float(item["value"]) - for item in records - if item.get("value") is not None + float(item["value"]) for item in records if item.get("value") is not None ] if not numeric_values: raise ValueError(f"{series_name} 在时间窗内无有效模拟数据: {sensor_id}") @@ -480,7 +474,9 @@ def _query_simulation_data_by_sensor_ids( if simulation_source not in {"scheme", "realtime"}: raise ValueError(f"Unsupported simulation_source: {simulation_source}") - result: dict[str, list[dict[str, Any]]] = {sensor_id: [] for sensor_id in sensor_ids} + result: dict[str, list[dict[str, Any]]] = { + sensor_id: [] for sensor_id in sensor_ids + } if data_type == "pressure": result.update( _query_simulation_values( @@ -611,9 +607,7 @@ def _build_sensor_metadata(network: str, data_type: str) -> dict[str, dict[str, def _build_scada_mapping(network: str, data_type: str) -> dict[str, str]: metadata = _build_sensor_metadata(network=network, data_type=data_type) - return { - element_id: item["query_id"] for element_id, item in metadata.items() - } + return {element_id: item["query_id"] for element_id, item in metadata.items()} def _normalize_data_source( @@ -624,7 +618,9 @@ def _normalize_data_source( return "simulation" if simulation_scheme_name else "monitoring" if normalized not in SIMULATION_DATA_SOURCES: allowed_sources = ", ".join(sorted(SIMULATION_DATA_SOURCES)) - raise ValueError(f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}") + raise ValueError( + f"Unsupported data_source: {data_source}. Allowed: {allowed_sources}" + ) return normalized diff --git a/app/services/leakage_identifier.py b/app/services/leakage_identifier.py index 359515f..38b9663 100644 --- a/app/services/leakage_identifier.py +++ b/app/services/leakage_identifier.py @@ -29,6 +29,7 @@ DEFAULT_N_WORKERS = max(1, min((os.cpu_count() or 1) - 1, 4)) def run_leakage_identification( network: str, + username: str, observed_pressure_data: ( str | pd.DataFrame | dict[str, list[Any]] | list[dict[str, Any]] | None ) = None, @@ -47,7 +48,6 @@ def run_leakage_identification( scada_end: datetime | str | None = None, sensor_nodes: list[str] | None = None, scheme_name: str | None = None, - username: str = "admin", ) -> dict[str, Any]: os.makedirs(output_dir, exist_ok=True) inp_path = _prepare_leakage_inp(network) diff --git a/tests/unit/test_burst_location_service.py b/tests/unit/test_burst_location_service.py index 508db0c..7b4fa08 100644 --- a/tests/unit/test_burst_location_service.py +++ b/tests/unit/test_burst_location_service.py @@ -74,7 +74,7 @@ def _load_burst_location_module(): return module -def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, tmp_path): +def test_run_burst_location_uses_single_timerange_with_burst_source_split(monkeypatch, tmp_path): module = _load_burst_location_module() captured = {} scheme_calls = [] @@ -162,14 +162,13 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, result = module.run_burst_location_by_network( network="tjwater", + username="testuser", data_source="simulation", simulation_scheme_name="BurstSchemeA", simulation_scheme_type="burst_analysis", burst_leakage=10.0, scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), - scada_normal_start=datetime(2025, 1, 1, 6, 0, 0), - scada_normal_end=datetime(2025, 1, 1, 7, 0, 0), use_scada_flow=True, ) @@ -193,9 +192,15 @@ def test_run_burst_location_uses_scheme_sources_for_full_timeranges(monkeypatch, assert any(call["element_type"] == "link" and call["field"] == "flow" for call in scheme_calls) assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in scheme_calls) assert len(realtime_calls) == 3 + assert all(datetime.fromisoformat(call["start_time"]).hour == 8 for call in realtime_calls) + assert all(datetime.fromisoformat(call["end_time"]).hour == 9 for call in realtime_calls) assert any(call["element_type"] == "node" and call["field"] == "pressure" for call in realtime_calls) assert any(call["element_type"] == "link" and call["field"] == "flow" for call in realtime_calls) assert any(call["element_type"] == "node" and call["field"] == "actual_demand" for call in realtime_calls) + assert result["scada_window"] == { + "burst_start": "2025-01-01T08:00:00", + "burst_end": "2025-01-01T09:00:00", + } def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_path): @@ -217,10 +222,9 @@ def test_run_burst_location_requires_simulation_scheme_name(monkeypatch, tmp_pat with pytest.raises(ValueError, match="simulation_scheme_name"): module.run_burst_location_by_network( network="tjwater", + username="testuser", data_source="simulation", burst_leakage=1.0, scada_burst_start=datetime(2025, 1, 1, 8, 0, 0), scada_burst_end=datetime(2025, 1, 1, 9, 0, 0), - scada_normal_start=datetime(2025, 1, 1, 6, 0, 0), - scada_normal_end=datetime(2025, 1, 1, 7, 0, 0), )