@staticmethod
async def clean_scada_data(
    timescale_conn: AsyncConnection,
    postgres_conn: AsyncConnection,
    device_ids: List[str],
    start_time: datetime,
    end_time: datetime,
) -> str:
    """Clean raw SCADA readings and store the result as ``cleaned_value``.

    For every requested device the method looks up the device type in the
    SCADA metadata, reads ``monitored_value`` for the time window (one query
    per device type), runs the type-specific cleaner from ``api_ex`` and
    writes each cleaned sample back via ``ScadaRepository.update_scada_field``.
    Only devices of type ``"pressure"`` and ``"pipe_flow"`` are processed;
    devices of any other type, or missing from the metadata, are skipped.

    Args:
        timescale_conn: TimescaleDB connection used to read the raw values
            and to write the cleaned values.
        postgres_conn: PostgreSQL connection used to read SCADA metadata.
        device_ids: IDs of the devices to clean. Repeated IDs are processed
            once; an empty list is a no-op.
        start_time: Start of the time window.
        end_time: End of the time window.

    Returns:
        ``"success"`` when everything was processed, otherwise
        ``"error: <message>"``. The method never raises: any exception is
        converted into the error string. It performs no rollback, so writes
        issued before a failure are left as they are.
    """
    try:
        # Drop repeated IDs (keeping request order): duplicates would create
        # duplicate DataFrame columns and make every write happen twice.
        requested_ids = list(dict.fromkeys(device_ids))
        if not requested_ids:
            # Nothing to clean, so skip the metadata query entirely.
            return "success"

        # Index the SCADA metadata by device id.
        scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
        info_by_id = {info["id"]: info for info in scada_infos}

        # Group the requested devices by type, keeping only cleanable types.
        supported_types = ("pressure", "pipe_flow")
        ids_by_type = {}
        for device_id in requested_ids:
            device_type = info_by_id.get(device_id, {}).get("type", "unknown")
            if device_type in supported_types:
                ids_by_type.setdefault(device_type, []).append(device_id)

        for device_type, ids in ids_by_type.items():
            # data: {device_id: [{"time": <ISO string>, "value": ...}, ...]}
            data = await ScadaRepository.get_scada_field_by_id_time_range(
                timescale_conn, ids, start_time, end_time, "monitored_value"
            )
            if not data:
                continue

            long_records = [
                {
                    "time": record["time"],
                    "device_id": device_id,
                    "value": record["value"],
                }
                for device_id, records in data.items()
                for record in records
            ]
            if not long_records:
                continue

            df_long = pd.DataFrame(long_records)
            # pivot() raises on repeated (time, device_id) pairs; keep the
            # last sample so one duplicated row cannot abort the whole batch.
            df_long = df_long.drop_duplicates(
                subset=["time", "device_id"], keep="last"
            )
            # Wide layout: one row per timestamp, one column per device.
            df = df_long.pivot(index="time", columns="device_id", values="value")

            # Every requested device gets a column, even without any data.
            for device_id in ids:
                if device_id not in df.columns:
                    df[device_id] = None

            # Keep only the requested columns and turn "time" back into a
            # regular column.
            df = df[ids].reset_index()

            # The cleaners take {device_id: [values...]} without the time axis.
            raw_by_device = df.drop(columns=["time"]).to_dict(orient="list")
            if device_type == "pressure":
                cleaned = api_ex.Pdataclean.clean_pressure_data_dict_km(
                    raw_by_device
                )
            else:  # "pipe_flow" is the only other supported type
                cleaned = api_ex.Fdataclean.clean_flow_data_dict(raw_by_device)

            # Re-attach the time column in front of the cleaned values.
            cleaned_df = pd.concat([df["time"], pd.DataFrame(cleaned)], axis=1)

            # The timestamps are ISO strings; parse them once per type group
            # instead of once per device.
            timestamps = [
                datetime.fromisoformat(time_str)
                for time_str in cleaned_df["time"].tolist()
            ]

            # Write the cleaned values back, one row at a time.
            for device_id in ids:
                if device_id not in cleaned_df.columns:
                    continue
                cleaned_values = cleaned_df[device_id].tolist()
                for time_dt, value in zip(timestamps, cleaned_values):
                    await ScadaRepository.update_scada_field(
                        timescale_conn,
                        time_dt,
                        device_id,
                        "cleaned_value",
                        value,
                    )

        return "success"
    except Exception as e:
        # Documented contract: failures are reported as a string, not raised.
        return f"error: {e}"
@router.post("/composite/clean-scada")
async def clean_scada_data(
    device_ids: str,
    start_time: datetime = Query(...),
    end_time: datetime = Query(...),
    timescale_conn: AsyncConnection = Depends(get_database_connection),
    postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
    """Clean SCADA data for a comma-separated list of devices.

    Splits ``device_ids`` on commas, trims whitespace, drops blank entries
    and forwards the resulting list together with the time window to
    ``CompositeQueries.clean_scada_data``, whose return value is sent back
    unchanged.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while handling
            the request.

    NOTE(review): a failure reported by the composite call may come back as
    a plain string in a normal response rather than as an HTTP error --
    confirm this is the intended contract for clients.
    """
    try:
        parsed_ids = []
        if device_ids:
            for raw_id in device_ids.split(","):
                trimmed = raw_id.strip()
                if trimmed:
                    parsed_ids.append(trimmed)

        outcome = await CompositeQueries.clean_scada_data(
            timescale_conn,
            postgres_conn,
            parsed_ids,
            start_time,
            end_time,
        )
        return outcome
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
使用 COPY 快速写入新数据 @@ -33,25 +34,27 @@ class RealtimeRepository: "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item["time"], - item["id"], - item.get("flow"), - item.get("friction"), - item.get("headloss"), - item.get("quality"), - item.get("reaction"), - item.get("setting"), - item.get("status"), - item.get("velocity"), - )) + await copy.write_row( + ( + item["time"], + item["id"], + item.get("flow"), + item.get("friction"), + item.get("headloss"), + item.get("quality"), + item.get("reaction"), + item.get("setting"), + item.get("status"), + item.get("velocity"), + ) + ) @staticmethod def insert_links_batch_sync(conn: Connection, data: List[dict]): """Batch insert for realtime.link_simulation using DELETE then COPY for performance (sync version).""" if not data: return - + # 假设同一批次的数据时间是相同的 target_time = data[0]["time"] @@ -61,7 +64,7 @@ class RealtimeRepository: # 1. 先删除该时间点的旧数据 cur.execute( "DELETE FROM realtime.link_simulation WHERE time = %s", - (target_time,) + (target_time,), ) # 2. 
使用 COPY 快速写入新数据 @@ -69,18 +72,20 @@ class RealtimeRepository: "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" ) as copy: for item in data: - copy.write_row(( - item["time"], - item["id"], - item.get("flow"), - item.get("friction"), - item.get("headloss"), - item.get("quality"), - item.get("reaction"), - item.get("setting"), - item.get("status"), - item.get("velocity"), - )) + copy.write_row( + ( + item["time"], + item["id"], + item.get("flow"), + item.get("friction"), + item.get("headloss"), + item.get("quality"), + item.get("reaction"), + item.get("setting"), + item.get("status"), + item.get("velocity"), + ) + ) @staticmethod async def get_link_by_time_range( @@ -111,7 +116,7 @@ class RealtimeRepository: end_time: datetime, link_id: str, field: str, - ) -> Any: + ) -> List[Dict[str, Any]]: # Validate field name to prevent SQL injection valid_fields = { "flow", @@ -127,13 +132,15 @@ class RealtimeRepository: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s" + "SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (start_time, end_time, link_id)) - row = await cur.fetchone() - return row[field] if row else None + rows = await cur.fetchall() + return [ + {"time": row["time"].isoformat(), "value": row[field]} for row in rows + ] @staticmethod async def get_links_field_by_time_range( @@ -141,7 +148,7 @@ class RealtimeRepository: start_time: datetime, end_time: datetime, field: str, - ) -> Any: + ) -> dict: # Validate field name to prevent SQL injection valid_fields = { "flow", @@ -157,13 +164,18 @@ class RealtimeRepository: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT id, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s" + 
"SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (start_time, end_time)) rows = await cur.fetchall() - return [{"id": row["id"], "value": row[field]} for row in rows] + result = defaultdict(list) + for row in rows: + result[row["id"]].append( + {"time": row["time"].isoformat(), "value": row[field]} + ) + return dict(result) @staticmethod async def update_link_field( @@ -209,7 +221,7 @@ class RealtimeRepository: async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]): if not data: return - + # 假设同一批次的数据时间是相同的 target_time = data[0]["time"] @@ -219,7 +231,7 @@ class RealtimeRepository: # 1. 先删除该时间点的旧数据 await cur.execute( "DELETE FROM realtime.node_simulation WHERE time = %s", - (target_time,) + (target_time,), ) # 2. 使用 COPY 快速写入新数据 @@ -227,20 +239,22 @@ class RealtimeRepository: "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item["time"], - item["id"], - item.get("actual_demand"), - item.get("total_head"), - item.get("pressure"), - item.get("quality"), - )) + await copy.write_row( + ( + item["time"], + item["id"], + item.get("actual_demand"), + item.get("total_head"), + item.get("pressure"), + item.get("quality"), + ) + ) @staticmethod def insert_nodes_batch_sync(conn: Connection, data: List[dict]): if not data: return - + # 假设同一批次的数据时间是相同的 target_time = data[0]["time"] @@ -250,7 +264,7 @@ class RealtimeRepository: # 1. 先删除该时间点的旧数据 cur.execute( "DELETE FROM realtime.node_simulation WHERE time = %s", - (target_time,) + (target_time,), ) # 2. 
使用 COPY 快速写入新数据 @@ -258,14 +272,16 @@ class RealtimeRepository: "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN" ) as copy: for item in data: - copy.write_row(( - item["time"], - item["id"], - item.get("actual_demand"), - item.get("total_head"), - item.get("pressure"), - item.get("quality"), - )) + copy.write_row( + ( + item["time"], + item["id"], + item.get("actual_demand"), + item.get("total_head"), + item.get("pressure"), + item.get("quality"), + ) + ) @staticmethod async def get_node_by_time_range( @@ -296,36 +312,43 @@ class RealtimeRepository: end_time: datetime, node_id: str, field: str, - ) -> Any: + ) -> List[Dict[str, Any]]: valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s" + "SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (start_time, end_time, node_id)) - row = await cur.fetchone() - return row[field] if row else None + rows = await cur.fetchall() + return [ + {"time": row["time"].isoformat(), "value": row[field]} for row in rows + ] @staticmethod async def get_nodes_field_by_time_range( conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str - ) -> Any: + ) -> dict: valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT id, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s" + "SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (start_time, end_time)) rows = await cur.fetchall() - return [{"id": row["id"], 
"value": row[field]} for row in rows] + result = defaultdict(list) + for row in rows: + result[row["id"]].append( + {"time": row["time"].isoformat(), "value": row[field]} + ) + return dict(result) @staticmethod async def update_node_field( diff --git a/timescaledb/schemas/scada.py b/timescaledb/schemas/scada.py index 6ce7a8a..e879c14 100644 --- a/timescaledb/schemas/scada.py +++ b/timescaledb/schemas/scada.py @@ -1,5 +1,6 @@ from typing import List, Any from datetime import datetime +from collections import defaultdict from psycopg import AsyncConnection, Connection, sql @@ -59,19 +60,25 @@ class ScadaRepository: start_time: datetime, end_time: datetime, field: str, - ) -> List[dict]: + ) -> dict: valid_fields = {"monitored_value", "cleaned_value"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT device_id, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)" + "SELECT device_id, time, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (start_time, end_time, device_ids)) rows = await cur.fetchall() - return [{"device_id": row["device_id"], field: row[field]} for row in rows] + result = defaultdict(list) + for row in rows: + result[row["device_id"]].append({ + "time": row["time"].isoformat(), + "value": row[field] + }) + return dict(result) @staticmethod async def update_scada_field( diff --git a/timescaledb/schemas/scheme.py b/timescaledb/schemas/scheme.py index e478cfe..205f6c9 100644 --- a/timescaledb/schemas/scheme.py +++ b/timescaledb/schemas/scheme.py @@ -1,5 +1,6 @@ from typing import List, Any, Dict from datetime import datetime, timedelta, timezone +from collections import defaultdict from psycopg import AsyncConnection, Connection, sql import globals @@ -135,7 +136,7 @@ class SchemeRepository: end_time: datetime, link_id: str, field: str, - ) -> Any: + ) -> 
List[Dict[str, Any]]: # Validate field name to prevent SQL injection valid_fields = { "flow", @@ -151,15 +152,15 @@ class SchemeRepository: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s" + "SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute( query, (scheme_type, scheme_name, start_time, end_time, link_id) ) - row = await cur.fetchone() - return row[field] if row else None + rows = await cur.fetchall() + return [{"time": row["time"].isoformat(), "value": row[field]} for row in rows] @staticmethod async def get_links_field_by_scheme_and_time_range( @@ -169,7 +170,7 @@ class SchemeRepository: start_time: datetime, end_time: datetime, field: str, - ) -> Any: + ) -> dict: # Validate field name to prevent SQL injection valid_fields = { "flow", @@ -185,13 +186,19 @@ class SchemeRepository: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT id, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s" + "SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (scheme_type, scheme_name, start_time, end_time)) rows = await cur.fetchall() - return [{"id": row["id"], "value": row[field]} for row in rows] + result = defaultdict(list) + for row in rows: + result[row["id"]].append({ + "time": row["time"].isoformat(), + "value": row[field] + }) + return dict(result) @staticmethod async def update_link_field( @@ -353,22 +360,22 @@ class SchemeRepository: end_time: datetime, node_id: str, field: str, - ) -> Any: + ) -> List[Dict[str, Any]]: # Validate field name 
to prevent SQL injection valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s" + "SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute( query, (scheme_type, scheme_name, start_time, end_time, node_id) ) - row = await cur.fetchone() - return row[field] if row else None + rows = await cur.fetchall() + return [{"time": row["time"].isoformat(), "value": row[field]} for row in rows] @staticmethod async def get_nodes_field_by_scheme_and_time_range( @@ -378,20 +385,26 @@ class SchemeRepository: start_time: datetime, end_time: datetime, field: str, - ) -> Any: + ) -> dict: # Validate field name to prevent SQL injection valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") query = sql.SQL( - "SELECT id, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s" + "SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s" ).format(sql.Identifier(field)) async with conn.cursor() as cur: await cur.execute(query, (scheme_type, scheme_name, start_time, end_time)) rows = await cur.fetchall() - return [{"id": row["id"], "value": row[field]} for row in rows] + result = defaultdict(list) + for row in rows: + result[row["id"]].append({ + "time": row["time"].isoformat(), + "value": row[field] + }) + return dict(result) @staticmethod async def update_node_field(