diff --git a/app/algorithms/burst_detection/__init__.py b/app/algorithms/burst_detection/__init__.py new file mode 100644 index 0000000..226dc73 --- /dev/null +++ b/app/algorithms/burst_detection/__init__.py @@ -0,0 +1,3 @@ +from app.algorithms.burst_detection.burst_detector import BurstDetector + +__all__ = ["BurstDetector"] diff --git a/app/algorithms/burst_detection/burst_detector.py b/app/algorithms/burst_detection/burst_detector.py new file mode 100644 index 0000000..cf5f43a --- /dev/null +++ b/app/algorithms/burst_detection/burst_detector.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pandas as pd +from scipy.fft import fft, ifft +from sklearn.ensemble import IsolationForest + +PressureDataInput = ( + pd.DataFrame + | dict[str, list[Any]] + | list[dict[str, Any]] + | list[list[Any]] + | np.ndarray +) +IGNORED_OBSERVATION_COLUMNS = {"time", "timestamp", "datetime", "date"} + + +class BurstDetector: + """FFT + IsolationForest based burst detection for daily aligned pressure data.""" + + def __init__( + self, + *, + mu: int = 100, + points_per_day: int = 1440, + iforest_params: dict[str, Any] | None = None, + ) -> None: + if points_per_day <= 0: + raise ValueError("points_per_day 必须大于 0。") + if mu <= 0: + raise ValueError("mu 必须大于 0。") + + self.mu = int(mu) + self.points_per_day = int(points_per_day) + self.iforest_params = { + "n_estimators": 50, + "random_state": 42, + "contamination": "auto", + } + if iforest_params: + self.iforest_params.update(iforest_params) + + self.data: np.ndarray | None = None + self.sensor_names: list[str] = [] + self.high_freq_features: np.ndarray | None = None + + def load_data( + self, + data_source: PressureDataInput, + *, + sensor_nodes: list[str] | None = None, + ) -> pd.DataFrame: + """ + 标准化输入观测数据为 DataFrame。 + + 支持的 `data_source` 格式: + - `pd.DataFrame` + 每一列代表一个传感器,每一行代表一个时间点。 + - `dict[str, list[Any]]` + 键为传感器 ID,值为该传感器按时间顺序排列的压力序列。 + 例如:`{"J1": [101.2, 101.0], "J2": [99.8, 99.7]}`。 + - `list[dict[str, Any]]` + 每个字典代表一个时间点,键为传感器 ID,值为该时刻压力。 + 例如:`[{"J1": 101.2, "J2": 99.8}, {"J1": 101.0, "J2": 99.7}]`。 + - `list[list[Any]]` + 二维列表,格式为 `(时间点数, 传感器数)`。 + 例如:`[[101.2, 99.8], [101.0, 99.7]]`。 + - `np.ndarray` + 二维数组,形状必须为 `(时间点数, 传感器数)`。 + + 参数: + - `sensor_nodes`: + 可选的传感器列筛选列表。传入后,数据中必须包含这些列名。 + + 返回: + - 标准化后的 `pd.DataFrame`,列为传感器,行为时间点。 + """ + if isinstance(data_source, np.ndarray): + observation_df = pd.DataFrame(data_source) + elif isinstance(data_source, pd.DataFrame): + observation_df = data_source.copy() + else: + observation_df = pd.DataFrame(data_source) + + return self._normalize_observation_frame( + observation_df=observation_df, sensor_nodes=sensor_nodes + ) + + def process( + self, + observed_pressure_data: PressureDataInput, + *, + sensor_nodes: list[str] | None = None, + ) -> np.ndarray: + """ + 对输入压力序列按天切片,并提取每天末时刻的高频特征。 + + `observed_pressure_data` 的格式与 `load_data()` 一致,统一要求: + - 数据必须表示为“行=时间点、列=传感器”。 + - 总行数必须是 `points_per_day` 的整数倍。 + - 至少需要 2 天数据,即总行数 `>= 2 * points_per_day`。 + + 例如: + - 当 `points_per_day=1440` 时,15 天数据的形状通常为 `(21600, 传感器数)`。 + - 若传入 `sensor_nodes=["J1", "J2"]`,则输入中必须存在 `J1/J2` 两列。 + + 返回: + - `np.ndarray`,形状为 `(天数, 传感器数)`, + 每个值表示对应传感器在当天末时刻提取出的高频分量。 + """ + observation_df = self.load_data( + observed_pressure_data, + sensor_nodes=sensor_nodes, + ) + matrix = observation_df.to_numpy(dtype=float) + total_points, sensor_count = matrix.shape + if sensor_count == 0: + raise ValueError("压力观测数据中未找到可用传感器列。") + if total_points < self.points_per_day * 2: + raise ValueError("至少需要 2 天的观测数据才能执行爆管侦测。") + if total_points % self.points_per_day != 0: + raise ValueError("观测数据长度必须能被 points_per_day 整除,以便按天切分。") + + day_count = total_points // self.points_per_day + high_freq_features = np.zeros((day_count, sensor_count), dtype=float) + + for sensor_idx in range(sensor_count): + sensor_series = matrix[:, sensor_idx] + for day_idx in range(day_count): + start = day_idx * self.points_per_day + end = (day_idx + 1) * self.points_per_day + day_data = sensor_series[start:end] + mirrored_data = np.concatenate([day_data, day_data[::-1]]) + transformed = fft(mirrored_data) + transformed[self.mu : len(mirrored_data) - self.mu + 1] = 0 + low_freq = ifft(transformed).real + high_freq = day_data - low_freq[: self.points_per_day] + high_freq_features[day_idx, sensor_idx] = float(high_freq[-1]) + + self.data = matrix + self.sensor_names = [str(column) for column in observation_df.columns] + self.high_freq_features = high_freq_features + return high_freq_features + + def detect(self) -> pd.DataFrame: + if self.high_freq_features is None: + raise ValueError("特征未提取。请先调用 process()。") + + day_count = self.high_freq_features.shape[0] + if day_count < 2: + raise ValueError("孤立森林至少需要 2 天特征数据。") + + clf = IsolationForest( + n_estimators=self.iforest_params.get("n_estimators", 50), + max_samples=day_count, + random_state=self.iforest_params.get("random_state", 42), + contamination=self.iforest_params.get("contamination", "auto"), + **{ + key: value + for key, value in self.iforest_params.items() + if key not in {"n_estimators", "random_state", "contamination"} + }, + ) + clf.fit(self.high_freq_features) + + scores = clf.decision_function(self.high_freq_features) + predictions = clf.predict(self.high_freq_features) + result_df = pd.DataFrame( + { + "Day": range(1, day_count + 1), + "Score": scores.astype(float), + "Prediction": predictions.astype(int), + } + ) + result_df["IsBurst"] = result_df["Prediction"].eq(-1) + result_df.attrs["sensor_nodes"] = self.sensor_names.copy() + result_df.attrs["high_freq_features"] = self.high_freq_features.copy() + result_df.attrs["day_count"] = day_count + result_df.attrs["points_per_day"] = self.points_per_day + result_df.attrs["sample_count"] = ( + int(self.data.shape[0]) if self.data is not None else 0 + ) + return result_df + + def run_detection( + self, + observed_pressure_data: PressureDataInput, + *, + sensor_nodes: list[str] | None = None, + ) -> pd.DataFrame: + """ + 执行完整爆管侦测流程。 + + 输入格式与 `process()` 相同: + - `DataFrame` / `dict[str, list[Any]]` / `list[dict[str, Any]]` / `list[list[Any]]` / `np.ndarray` + - 行表示时间点,列表示传感器 + - 总行数必须能被 `points_per_day` 整除 + + 返回结果包含列: + - `Day`: 第几天(从 1 开始) + - `Score`: IsolationForest 异常分数,越小越异常 + - `Prediction`: `-1` 表示异常,`1` 表示正常 + - `IsBurst`: 是否判定为异常日 + """ + self.process(observed_pressure_data, sensor_nodes=sensor_nodes) + return self.detect() + + @staticmethod + def _normalize_observation_frame( + *, + observation_df: pd.DataFrame, + sensor_nodes: list[str] | None, + ) -> pd.DataFrame: + if observation_df.empty: + raise ValueError("压力观测数据为空。") + + normalized_df = observation_df.copy() + normalized_df.columns = [str(column) for column in normalized_df.columns] + normalized_df = normalized_df.drop( + columns=[ + column + for column in normalized_df.columns + if column.lower() in IGNORED_OBSERVATION_COLUMNS + or column.lower().startswith("unnamed:") + ], + errors="ignore", + ) + + if sensor_nodes: + selected_columns = [str(node) for node in sensor_nodes] + missing_columns = [ + column + for column in selected_columns + if column not in normalized_df.columns + ] + if missing_columns: + preview = ", ".join(missing_columns[:10]) + raise ValueError(f"观测数据缺少传感器列: {preview}") + normalized_df = normalized_df.loc[:, selected_columns] + else: + candidate_df = normalized_df.apply(pd.to_numeric, errors="coerce") + normalized_df = candidate_df.loc[:, candidate_df.notna().any(axis=0)] + + if normalized_df.empty: + raise ValueError("未识别到可用的数值型压力观测列。") + + normalized_df = normalized_df.apply(pd.to_numeric, errors="coerce") + invalid_columns = [ + column + for column in normalized_df.columns + if normalized_df[column].isna().any() + ] + if invalid_columns: + preview = ", ".join(invalid_columns[:10]) + raise ValueError(f"压力观测数据包含非数值或缺失值: {preview}") + + return normalized_df.reset_index(drop=True) diff --git a/app/api/v1/endpoints/burst_detection.py b/app/api/v1/endpoints/burst_detection.py new file mode 100644 index 0000000..b102e08 --- /dev/null +++ b/app/api/v1/endpoints/burst_detection.py @@ -0,0 +1,68 @@ +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from app.auth.keycloak_dependencies import get_current_keycloak_username +from app.services.burst_detection import ( + get_burst_detection_scheme_detail, + list_burst_detection_schemes, + run_burst_detection, +) + +router = APIRouter() + + +class BurstDetectionRequest(BaseModel): + network: str + observed_pressure_data: ( + dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] | None + ) = Field( + default=None, + description=( + "压力观测数据。支持列式字典 {sensor_id: [values,...]}、" + "逐时刻对象数组 [{sensor_id: value,...}, ...]、" + "或二维数组 [[t1_s1, t1_s2], [t2_s1, t2_s2], ...]。" + ), + ) + points_per_day: int = 1440 + mu: int = 100 + iforest_params: dict[str, Any] | None = None + scada_start: datetime | None = None + scada_end: datetime | None = None + sensor_nodes: list[str] | None = None + scheme_name: str | None = None + + +@router.post("/detect/") +async def detect_burst( + data: BurstDetectionRequest, + username: str = Depends(get_current_keycloak_username), +) -> dict[str, Any]: + try: + return run_burst_detection(**data.model_dump(), username=username) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/schemes/") +async def query_burst_detection_schemes( + network: str, + query_date: datetime | None = None, +) -> list[dict[str, Any]]: + try: + return list_burst_detection_schemes(network=network, query_date=query_date) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/schemes/{scheme_name}") +async def query_burst_detection_scheme_detail( + network: str, + scheme_name: str, +) -> dict[str, Any]: + try: + return get_burst_detection_scheme_detail(network=network, scheme_name=scheme_name) + except Exception as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/app/api/v1/router.py b/app/api/v1/router.py index 2768142..53a6125 100644 --- a/app/api/v1/router.py +++ b/app/api/v1/router.py @@ -13,6 +13,7 @@ from app.api.v1.endpoints import ( risk, cache, leakage, + burst_detection, burst_location, user_management, # 新增:用户管理 audit, # 新增:审计日志 @@ -91,6 +92,9 @@ api_router.include_router(misc.router, tags=["Misc"]) api_router.include_router(risk.router, tags=["Risk"]) api_router.include_router(cache.router, tags=["Cache"]) api_router.include_router(leakage.router, prefix="/leakage", tags=["Leakage"]) +api_router.include_router( + burst_detection.router, prefix="/burst-detection", tags=["Burst Detection"] +) api_router.include_router( burst_location.router, prefix="/burst-location", tags=["Burst Location"] ) diff --git a/app/services/burst_detection.py b/app/services/burst_detection.py new file mode 100644 index 0000000..bdd297a --- /dev/null +++ b/app/services/burst_detection.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +import pandas as pd + +from app.algorithms.burst_detection.burst_detector import BurstDetector +from app.infra.db.timescaledb.internal_queries import InternalQueries +from app.services.scheme_management import ( + query_burst_detection_scheme_detail, + query_burst_detection_schemes, + scheme_name_exists, + store_scheme_info, +) +from app.services.tjnetwork import get_all_scada_info + + +def run_burst_detection( + *, + network: str, + username: str, + observed_pressure_data: ( + pd.DataFrame + | dict[str, list[Any]] + | list[dict[str, Any]] + | list[list[Any]] + | None + ) = None, + points_per_day: int = 1440, + mu: int = 100, + iforest_params: dict[str, Any] | None = None, + scada_start: datetime | str | None = None, + scada_end: datetime | str | None = None, + sensor_nodes: list[str] | None = None, + scheme_name: str | None = None, +) -> dict[str, Any]: + """ + 运行爆管侦测服务入口。 + + 调用方式二选一: + - 直接传 `observed_pressure_data` + - 或传 `scada_start/scada_end` 让后端自动查询 SCADA 压力数据 + + `observed_pressure_data` 支持格式: + - `pd.DataFrame` + 行表示时间点,列表示传感器;列名应为传感器/节点 ID。 + - `dict[str, list[Any]]` + 键为传感器/节点 ID,值为按时间顺序排列的压力序列。 + 例如:`{"J1": [101.2, 101.0], "J2": [99.8, 99.7]}`。 + - `list[dict[str, Any]]` + 每个元素代表一个时间点的多传感器观测。 + 例如:`[{"J1": 101.2, "J2": 99.8}, {"J1": 101.0, "J2": 99.7}]`。 + - `list[list[Any]]` + 二维数组式 JSON,格式为 `(时间点数, 传感器数)`。 + 这是最接近原始 `burst_detector` 示例代码的调用方式。 + + 数据约束: + - 统一要求“行=时间点,列=传感器”。 + - 总样本点数必须能被 `points_per_day` 整除。 + - 至少要有 2 天数据,即 `sample_count >= 2 * points_per_day`。 + - 若传入 `sensor_nodes`,输入数据必须包含这些列;SCADA 模式下也会只按这些节点取数。 + """ + if not network: + raise ValueError("network is required.") + + selected_sensor_nodes = ( + list(dict.fromkeys([node for node in (sensor_nodes or []) if node])) + if sensor_nodes + else None + ) + use_scada_source = scada_start is not None or scada_end is not None + + if use_scada_source: + scada_sensor_nodes = ( + selected_sensor_nodes + if selected_sensor_nodes is not None + else _get_pressure_sensor_nodes(network) + ) + observed_df = _build_observed_pressure_from_scada( + network=network, + sensor_nodes=scada_sensor_nodes, + scada_start=scada_start, + scada_end=scada_end, + ) + observed_input: pd.DataFrame | dict[str, list[Any]] | list[dict[str, Any]] | list[list[Any]] = observed_df + observed_source = "backend_timerange" + else: + if observed_pressure_data is None: + raise ValueError( + "未提供 observed_pressure_data,且未提供 scada_start/scada_end。" + ) + observed_input = observed_pressure_data + observed_source = "request_payload" + + detector = BurstDetector( + mu=mu, + points_per_day=points_per_day, + iforest_params=iforest_params, + ) + result_df = detector.run_detection( + observed_input, + sensor_nodes=selected_sensor_nodes, + ) + resolved_sensor_nodes = list(result_df.attrs.get("sensor_nodes", [])) + rows = _serialize_result_rows(result_df) + payload: dict[str, Any] = { + "network": network, + "sensor_nodes": resolved_sensor_nodes, + "observed_source": observed_source, + "sample_count": int(result_df.attrs.get("sample_count", 0)), + "points_per_day": int(result_df.attrs.get("points_per_day", points_per_day)), + "day_count": int(result_df.attrs.get("day_count", len(result_df))), + "rows": rows, + "summary": _build_detection_summary(result_df), + } + if use_scada_source: + payload["scada_window"] = { + "start": _to_datetime(scada_start).isoformat(), + "end": _to_datetime(scada_end).isoformat(), + } + if scheme_name: + _store_burst_detection_scheme( + network=network, + scheme_name=scheme_name, + username=username, + payload=payload, + mu=mu, + points_per_day=points_per_day, + iforest_params=detector.iforest_params, + ) + payload["scheme_name"] = scheme_name + return payload + + +def list_burst_detection_schemes( + network: str, + query_date: datetime | str | None = None, +) -> list[dict[str, Any]]: + parsed_date = _to_datetime(query_date).date() if query_date is not None else None + return query_burst_detection_schemes( + name=network, + network=network, + query_date=parsed_date, + ) + + +def get_burst_detection_scheme_detail(network: str, scheme_name: str) -> dict[str, Any]: + result = query_burst_detection_scheme_detail(network, scheme_name) + if not result: + raise ValueError(f"未找到爆管侦测方案: {scheme_name}") + return result + + +def _store_burst_detection_scheme( + *, + network: str, + scheme_name: str, + username: str, + payload: dict[str, Any], + mu: int, + points_per_day: int, + iforest_params: dict[str, Any], +) -> None: + if scheme_name_exists(network, scheme_name): + raise ValueError(f"方案名称已存在: {scheme_name}") + + now_iso = datetime.now().isoformat() + scheme_detail = { + "network": network, + "sensor_nodes": payload.get("sensor_nodes", []), + "observed_source": payload.get("observed_source"), + "scada_window": payload.get("scada_window"), + "algorithm_params": { + "mu": mu, + "points_per_day": points_per_day, + "iforest_params": iforest_params, + }, + "result_summary": payload.get("summary", {}), + "result_payload": payload, + } + store_scheme_info( + name=network, + scheme_name=scheme_name, + scheme_type="burst_detection", + username=username, + scheme_start_time=now_iso, + scheme_detail=scheme_detail, + ) + + +def _serialize_result_rows(result_df: pd.DataFrame) -> list[dict[str, Any]]: + rows: list[dict[str, Any]] = [] + for row in result_df.to_dict(orient="records"): + rows.append( + { + "Day": int(row["Day"]), + "Score": float(row["Score"]), + "Prediction": int(row["Prediction"]), + "IsBurst": bool(row["IsBurst"]), + } + ) + return rows + + +def _build_detection_summary(result_df: pd.DataFrame) -> dict[str, Any]: + rows = _serialize_result_rows(result_df) + if not rows: + raise ValueError("爆管侦测结果为空。") + + score_series = result_df["Score"] + most_anomalous_index = int(score_series.idxmin()) + latest_row = rows[-1] + anomaly_days = [row["Day"] for row in rows if row["IsBurst"]] + + return { + "burst_detected": bool(latest_row["IsBurst"]), + "latest_day": latest_row, + "most_anomalous_day": int(result_df.iloc[most_anomalous_index]["Day"]), + "anomaly_days": anomaly_days, + "anomaly_day_count": len(anomaly_days), + "latest_sensor_rankings": _build_latest_sensor_rankings(result_df), + } + + +def _build_latest_sensor_rankings(result_df: pd.DataFrame) -> list[dict[str, Any]]: + feature_matrix = result_df.attrs.get("high_freq_features") + sensor_nodes = list(result_df.attrs.get("sensor_nodes", [])) + if feature_matrix is None or len(sensor_nodes) == 0: + return [] + + latest_values = feature_matrix[-1] + ranking = sorted( + zip(sensor_nodes, latest_values, strict=False), + key=lambda item: item[1], + ) + return [ + { + "sensor_node": sensor_id, + "latest_high_frequency_value": float(value), + } + for sensor_id, value in ranking[: min(10, len(ranking))] + ] + + +def _get_pressure_sensor_nodes(network: str) -> list[str]: + sensor_nodes: list[str] = [] + for item in get_all_scada_info(network): + if str(item.get("type", "")).lower() != "pressure": + continue + node_id = item.get("associated_element_id") + if isinstance(node_id, str) and node_id: + sensor_nodes.append(node_id) + sensor_nodes = list(dict.fromkeys(sensor_nodes)) + if not sensor_nodes: + raise ValueError("未找到压力传感器对应节点(scada_info.type=pressure)。") + return sensor_nodes + + +def _build_observed_pressure_from_scada( + *, + network: str, + sensor_nodes: list[str], + scada_start: datetime | str | None, + scada_end: datetime | str | None, +) -> pd.DataFrame: + if scada_start is None or scada_end is None: + raise ValueError("使用后端 SCADA 查询时必须同时提供 scada_start 与 scada_end。") + + start_dt = _to_datetime(scada_start) + end_dt = _to_datetime(scada_end) + if start_dt >= end_dt: + raise ValueError("SCADA 时间窗非法:scada_start 必须早于 scada_end。") + + node_query_id: dict[str, str] = {} + for item in get_all_scada_info(network): + if str(item.get("type", "")).lower() != "pressure": + continue + node_id = item.get("associated_element_id") + query_id = item.get("api_query_id") + if ( + isinstance(node_id, str) + and node_id + and isinstance(query_id, str) + and query_id + ): + node_query_id[node_id] = query_id + + missing_nodes = [node_id for node_id in sensor_nodes if node_id not in node_query_id] + if missing_nodes: + preview = ", ".join(missing_nodes[:10]) + raise ValueError(f"未找到可用于压力观测的 SCADA api_query_id: {preview}") + + query_ids = [node_query_id[node_id] for node_id in sensor_nodes] + scada_data = InternalQueries.query_scada_by_ids_timerange( + db_name=network, + device_ids=query_ids, + start_time=start_dt.isoformat(), + end_time=end_dt.isoformat(), + ) + + available_lengths = [ + len(scada_data.get(query_id, [])) + for query_id in query_ids + if len(scada_data.get(query_id, [])) > 0 + ] + if not available_lengths: + raise ValueError("指定时间窗内未查询到压力 SCADA 数据。") + + min_len = min(available_lengths) + observation_df = pd.DataFrame() + for node_id in sensor_nodes: + query_id = node_query_id[node_id] + records = scada_data.get(query_id, [])[:min_len] + if len(records) < min_len: + continue + observation_df[node_id] = [float(item["value"]) for item in records] + + if observation_df.empty: + raise ValueError("SCADA 压力数据无法构建观测矩阵。") + return observation_df + + +def _to_datetime(value: datetime | str) -> datetime: + if isinstance(value, datetime): + return value + return datetime.fromisoformat(value) diff --git a/app/services/scheme_management.py b/app/services/scheme_management.py index 80fc3b3..a86a9bd 100644 --- a/app/services/scheme_management.py +++ b/app/services/scheme_management.py @@ -401,6 +401,85 @@ def query_burst_location_scheme_detail(name: str, scheme_name: str) -> dict: } +def query_burst_detection_schemes( + name: str, + network: str, + scheme_type: str = "burst_detection", + query_date: date | None = None, +) -> list[dict]: + conn_string = get_pgconn_string(db_name=name) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + if query_date is None: + cur.execute( + """ + SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail + FROM public.scheme_list + WHERE scheme_type = %s + ORDER BY create_time DESC + """, + (scheme_type,), + ) + else: + cur.execute( + """ + SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail + FROM public.scheme_list + WHERE scheme_type = %s AND DATE(create_time) = %s + ORDER BY create_time DESC + """, + (scheme_type, query_date), + ) + rows = cur.fetchall() + result = [] + for row in rows: + detail = row[6] if isinstance(row[6], dict) else {} + if network and detail.get("network") not in (None, network): + continue + result.append( + { + "scheme_id": row[0], + "scheme_name": row[1], + "scheme_type": row[2], + "username": row[3], + "create_time": row[4], + "scheme_start_time": row[5], + "scheme_detail": detail, + } + ) + return result + + +def query_burst_detection_scheme_detail(name: str, scheme_name: str) -> dict: + conn_string = get_pgconn_string(db_name=name) + with psycopg.connect(conn_string) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail + FROM public.scheme_list + WHERE scheme_name = %s + LIMIT 1 + """, + (scheme_name,), + ) + base_row = cur.fetchone() + if base_row is None: + return {} + detail = base_row[6] if isinstance(base_row[6], dict) else {} + return { + "scheme_id": base_row[0], + "scheme_name": base_row[1], + "scheme_type": base_row[2], + "username": base_row[3], + "create_time": base_row[4], + "scheme_start_time": base_row[5], + "scheme_detail": detail, + "network": detail.get("network"), + "result_payload": detail.get("result_payload", {}), + } + + # 2025/03/23 def upload_shp_to_pg(name: str, table_name: str, role: str, shp_file_path: str): """ diff --git a/requirements.txt b/requirements.txt index 9391bd3..9881257 100644 --- a/requirements.txt +++ b/requirements.txt @@ -165,4 +165,6 @@ wntr==1.3.2 wrapt==1.17.3 zipp==3.23.0 zmq==0.0.0 -pymoo==0.6.1.6 \ No newline at end of file +pymoo==0.6.1.6 +scikit-learn +scipy \ No newline at end of file