from typing import List, Any
from datetime import datetime
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql


class ScadaRepository:
    """Data-access helpers for the ``scada.scada_data`` table.

    Every method is static and takes an already-open psycopg connection.
    None of the methods call ``commit()``; committing (or relying on the
    connection's autocommit setting) is left to the caller.
    """

    # Value columns that may be addressed by name. Anything else is rejected
    # before it reaches ``sql.Identifier`` so callers cannot target arbitrary
    # columns.
    _VALUE_FIELDS = frozenset({"monitored_value", "cleaned_value"})

    # Shared by the async and sync range readers so the two cannot drift apart.
    _SELECT_BY_IDS_TIME_RANGE = (
        "SELECT * FROM scada.scada_data "
        "WHERE device_id = ANY(%s) AND time >= %s AND time <= %s"
    )

    @staticmethod
    def _check_field(field: str) -> None:
        """Raise ``ValueError`` unless *field* is one of ``_VALUE_FIELDS``."""
        if field not in ScadaRepository._VALUE_FIELDS:
            raise ValueError(f"Invalid field: {field}")

    @staticmethod
    async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
        """Bulk-insert rows into ``scada.scada_data`` using ``COPY ... FROM STDIN``.

        Each item must have ``time`` and ``device_id`` keys;
        ``monitored_value`` and ``cleaned_value`` are optional and are sent
        as ``None`` when missing. An empty (or falsy) *data* is a no-op.
        """
        if not data:
            return
        async with conn.cursor() as cur:
            async with cur.copy(
                "COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) "
                "FROM STDIN"
            ) as copy:
                for item in data:
                    await copy.write_row(
                        (
                            item["time"],
                            item["device_id"],
                            item.get("monitored_value"),
                            item.get("cleaned_value"),
                        )
                    )

    @staticmethod
    async def get_scada_by_ids_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Return all rows for *device_ids* with ``start_time <= time <= end_time``.

        Both bounds are inclusive. Row objects are whatever the connection's
        row factory produces (dicts are expected, per the annotation). When
        *device_ids* is empty, ``[]`` is returned without querying.
        """
        if not device_ids:
            return []
        async with conn.cursor() as cur:
            await cur.execute(
                ScadaRepository._SELECT_BY_IDS_TIME_RANGE,
                (device_ids, start_time, end_time),
            )
            return await cur.fetchall()

    @staticmethod
    def get_scada_by_ids_time_range_sync(
        conn: Connection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
    ) -> List[dict]:
        """Synchronous counterpart of ``get_scada_by_ids_time_range``.

        Runs the same query with the same inclusive bounds, and likewise
        returns ``[]`` without querying when *device_ids* is empty.
        """
        if not device_ids:
            return []
        with conn.cursor() as cur:
            cur.execute(
                ScadaRepository._SELECT_BY_IDS_TIME_RANGE,
                (device_ids, start_time, end_time),
            )
            return cur.fetchall()

    @staticmethod
    async def get_scada_field_by_id_time_range(
        conn: AsyncConnection,
        device_ids: List[str],
        start_time: datetime,
        end_time: datetime,
        field: str,
    ) -> dict:
        """Return one value column per device as time-ordered series.

        Result shape: ``{device_id: [{"time": iso_string, "value": v}, ...]}``.
        Each list is ordered by ``time`` ascending. Devices with no rows in
        the range are absent from the result.

        Raises:
            ValueError: if *field* is not ``monitored_value`` or
                ``cleaned_value``.
        """
        # Validate first so a bad field is reported even for empty device_ids.
        ScadaRepository._check_field(field)
        if not device_ids:
            return {}
        # ORDER BY guarantees each per-device list below is chronological;
        # without it PostgreSQL may return rows in any order.
        query = sql.SQL(
            "SELECT device_id, time, {} FROM scada.scada_data "
            "WHERE time >= %s AND time <= %s AND device_id = ANY(%s) "
            "ORDER BY time"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (start_time, end_time, device_ids))
            rows = await cur.fetchall()
        # Rows are accessed by column name, so this assumes a dict-style row
        # factory on the connection (NOTE(review): confirm at connection setup).
        result = defaultdict(list)
        for row in rows:
            result[row["device_id"]].append(
                {"time": row["time"].isoformat(), "value": row[field]}
            )
        return dict(result)

    @staticmethod
    async def update_scada_field(
        conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
    ) -> int:
        """Set *field* to *value* on the row keyed by (*time*, *device_id*).

        Returns:
            The number of rows updated (``cur.rowcount``). 0 means no row
            matched, which previously went unnoticed.

        Raises:
            ValueError: if *field* is not ``monitored_value`` or
                ``cleaned_value``.
        """
        ScadaRepository._check_field(field)
        query = sql.SQL(
            "UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
        ).format(sql.Identifier(field))
        async with conn.cursor() as cur:
            await cur.execute(query, (value, time, device_id))
            return cur.rowcount

    @staticmethod
    async def delete_scada_by_id_time_range(
        conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
    ) -> int:
        """Delete rows for *device_id* with ``start_time <= time <= end_time``.

        Returns:
            The number of rows deleted (``cur.rowcount``).
        """
        async with conn.cursor() as cur:
            await cur.execute(
                "DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
                (device_id, start_time, end_time),
            )
            return cur.rowcount