diff --git a/main.py b/main.py index fa7809b..683e5ca 100644 --- a/main.py +++ b/main.py @@ -3123,6 +3123,8 @@ async def fastapi_query_cleaning_scada_data_by_device_id_and_time_range( return influxdb_api.query_cleaning_SCADA_data_by_device_ID_and_timerange( query_ids_list=query_ids, start_time=starttime, end_time=endtime ) + + # 查询到的SCADA模拟数据(从 realtime_simulation bucket 中查找) @app.get("/querysimulationscadadatabydeviceidandtimerange/") async def fastapi_query_simulation_scada_data_by_device_id_and_time_range( @@ -3447,7 +3449,7 @@ async def fastapi_run_simulation_manually_by_date( item["name"], region_result ) ) - + ( globals.source_outflow_region_patterns, globals.realtime_region_pipe_flow_and_demand_patterns, @@ -4212,9 +4214,10 @@ if __name__ == "__main__": # uvicorn.run(app, host="0.0.0.0", port=8000) # url='http://127.0.0.1:8000/valve_close_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&valve_IDs=GSD2307192058577780A3287D78&valve_IDs=GSD2307192058572E953B707226(S2)&duration=1800' # url='http://127.0.0.1:8000/burst_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&burst_ID=ZBBGXSZW000001&duration=1800' - url = "http://192.168.1.36:8000/queryallschemeallrecords/?schemename=Fangan0817114448&querydate=2025-08-13&schemetype=burst_Analysis" + # url = "http://192.168.1.36:8000/queryallschemeallrecords/?schemename=Fangan0817114448&querydate=2025-08-13&schemetype=burst_Analysis" # response = Request.get(url) - import requests + # import requests - response = requests.get(url) + # response = requests.get(url) + print(get_all_scada_info("szh")) diff --git a/postgresql/database.py b/postgresql/database.py new file mode 100644 index 0000000..fae2904 --- /dev/null +++ b/postgresql/database.py @@ -0,0 +1,104 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator, Dict, Optional +import psycopg_pool +from psycopg.rows import dict_row +import postgresql_info + +# Configure logging +logger = logging.getLogger(__name__) + + +class Database: + def __init__(self, db_name=None): + self.pool = None + self.db_name = db_name + + def init_pool(self, db_name=None): + """Initialize the connection pool.""" + # Use provided db_name, or the one from constructor, or default from config + target_db_name = db_name or self.db_name + conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name) + try: + self.pool = psycopg_pool.AsyncConnectionPool( + conninfo=conn_string, + min_size=1, + max_size=20, + open=False, # Don't open immediately, wait for startup + kwargs={"row_factory": dict_row}, # Return rows as dictionaries + ) + logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}") + except Exception as e: + logger.error(f"Failed to initialize postgresql connection pool: {e}") + raise + + async def open(self): + if self.pool: + await self.pool.open() + + async def close(self): + """Close the connection pool.""" + if self.pool: + await self.pool.close() + logger.info("PostgreSQL connection pool closed.") + + @asynccontextmanager + async def get_connection(self) -> AsyncGenerator: + """Get a connection from the pool.""" + if not self.pool: + raise Exception("Database pool is not initialized.") + + async with self.pool.connection() as conn: + yield conn + + +# 默认数据库实例 +db = Database() + +# 缓存不同数据库的实例 - 避免重复创建连接池 +_database_instances: Dict[str, Database] = {} + +def create_database_instance(db_name): + """Create a new Database instance for a specific database.""" + return Database(db_name=db_name) + +async def get_database_instance(db_name: Optional[str] = None) -> Database: + """Get or create a database instance for the specified database name.""" + if not db_name: + return db # 返回默认数据库实例 + + if db_name not in _database_instances: + # 创建新的数据库实例 + instance = create_database_instance(db_name) + instance.init_pool() + await instance.open() + _database_instances[db_name] = instance + logger.info(f"Created new database instance for: {db_name}") + + return _database_instances[db_name] + +async def get_db_connection(): + """Dependency for FastAPI to get a database connection.""" + async with db.get_connection() as conn: + yield conn + +async def get_database_connection(db_name: Optional[str] = None): + """ + FastAPI dependency to get database connection with optional database name. + 使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name")) + 或在路由函数中: conn: AsyncConnection = Depends(get_database_connection) + """ + instance = await get_database_instance(db_name) + async with instance.get_connection() as conn: + yield conn + +async def cleanup_database_instances(): + """Clean up all database instances (call this on application shutdown).""" + for db_name, instance in _database_instances.items(): + await instance.close() + logger.info(f"Closed database instance for: {db_name}") + _database_instances.clear() + + # 关闭默认数据库 + await db.close() + logger.info("All database instances cleaned up.") diff --git a/postgresql/router.py b/postgresql/router.py new file mode 100644 index 0000000..1c6ed97 --- /dev/null +++ b/postgresql/router.py @@ -0,0 +1,40 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from typing import List, Optional +from datetime import datetime +from psycopg import AsyncConnection + +from .database import get_database_instance +from .scada_info import query_pg_scada_info + +router = APIRouter(prefix="/postgresql", tags=["postgresql"]) + + +# 创建支持数据库选择的连接依赖函数 +async def get_database_connection( + db_name: Optional[str] = Query( + None, description="指定要连接的数据库名称,为空时使用默认数据库" + ) +): + """获取数据库连接,支持通过查询参数指定数据库名称""" + instance = await get_database_instance(db_name) + async with instance.get_connection() as conn: + yield conn + + +@router.get("/scada-info") +async def get_scada_info_with_connection( + conn: AsyncConnection = Depends(get_database_connection), +): + """ + 使用连接池查询SCADA信息 + """ + try: + # 使用连接查询SCADA信息 + async with conn.cursor() as cur: + await cur.execute("SELECT * FROM scada_info") + scada_data = await cur.fetchall() + return {"success": True, "data": scada_data, "count": len(scada_data)} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}" + ) \ No newline at end of file diff --git a/postgresql/scada_info.py b/postgresql/scada_info.py new file mode 100644 index 0000000..b96af6b --- /dev/null +++ b/postgresql/scada_info.py @@ -0,0 +1,35 @@ +from typing import List, Optional, Any +from psycopg import AsyncConnection + + +class ScadaRepository: + + @staticmethod + async def get_scadas_info(conn: AsyncConnection) -> List[dict]: + """ + 查询pg数据库中,scada_info 的所有记录 + :param conn: 异步数据库连接 + :return: 包含所有记录的列表,每条记录为一个字典 + """ + async with conn.cursor() as cur: + await cur.execute( + """ + SELECT id, type, transmission_mode, transmission_frequency, reliability + FROM public.scada_info + """ + ) + records = await cur.fetchall() + + # 将查询结果转换为字典列表 + records_list = [] + for record in records: + record_dict = { + "id": record[0], + "type": record[1], + "transmission_mode": record[2], + "transmission_frequency": record[3], + "reliability": record[4], + } + records_list.append(record_dict) + + return records_list diff --git a/script/package/PyMetis-2018.1-cp34-cp34m-win_amd64.whl b/script/package/PyMetis-2018.1-cp34-cp34m-win_amd64.whl deleted file mode 100644 index eb0f989..0000000 Binary files a/script/package/PyMetis-2018.1-cp34-cp34m-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2019.1.1-cp35-cp35m-win_amd64.whl b/script/package/PyMetis-2019.1.1-cp35-cp35m-win_amd64.whl deleted file mode 100644 index c9c7475..0000000 Binary files a/script/package/PyMetis-2019.1.1-cp35-cp35m-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2019.1.1-cp36-cp36m-win_amd64.whl b/script/package/PyMetis-2019.1.1-cp36-cp36m-win_amd64.whl deleted file mode 100644 index 139a7a7..0000000 Binary files a/script/package/PyMetis-2019.1.1-cp36-cp36m-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2020.1-cp310-cp310-win_amd64.whl b/script/package/PyMetis-2020.1-cp310-cp310-win_amd64.whl deleted file mode 100644 index b6f524c..0000000 Binary files a/script/package/PyMetis-2020.1-cp310-cp310-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2020.1-cp37-cp37m-win_amd64.whl b/script/package/PyMetis-2020.1-cp37-cp37m-win_amd64.whl deleted file mode 100644 index 076c417..0000000 Binary files a/script/package/PyMetis-2020.1-cp37-cp37m-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2020.1-cp38-cp38-win_amd64.whl b/script/package/PyMetis-2020.1-cp38-cp38-win_amd64.whl deleted file mode 100644 index 11b739e..0000000 Binary files a/script/package/PyMetis-2020.1-cp38-cp38-win_amd64.whl and /dev/null differ diff --git a/script/package/PyMetis-2020.1-cp39-cp39-win_amd64.whl b/script/package/PyMetis-2020.1-cp39-cp39-win_amd64.whl deleted file mode 100644 index e77bcee..0000000 Binary files a/script/package/PyMetis-2020.1-cp39-cp39-win_amd64.whl and /dev/null differ diff --git a/script/package/pkg-pymetis-2023.1.1-py312h95578b8_4.tar b/script/package/pkg-pymetis-2023.1.1-py312h95578b8_4.tar deleted file mode 100644 index 0dd6b78..0000000 Binary files a/script/package/pkg-pymetis-2023.1.1-py312h95578b8_4.tar and /dev/null differ diff --git a/software/pg14/pg14.z01 b/software/pg14/pg14.z01 deleted file mode 100644 index 60738c5..0000000 Binary files a/software/pg14/pg14.z01 and /dev/null differ diff --git a/software/pg14/pg14.z02 b/software/pg14/pg14.z02 deleted file mode 100644 index 7df62bb..0000000 Binary files a/software/pg14/pg14.z02 and /dev/null differ diff --git a/software/pg14/pg14.z03 b/software/pg14/pg14.z03 deleted file mode 100644 index 4cc4179..0000000 Binary files a/software/pg14/pg14.z03 and /dev/null differ diff --git a/software/pg14/pg14.z04 b/software/pg14/pg14.z04 deleted file mode 100644 index 2b69c53..0000000 Binary files a/software/pg14/pg14.z04 and /dev/null differ diff --git a/software/pg14/pg14.z05 b/software/pg14/pg14.z05 deleted file mode 100644 index 57dcaec..0000000 Binary files a/software/pg14/pg14.z05 and /dev/null differ diff --git a/software/pg14/pg14.z06 b/software/pg14/pg14.z06 deleted file mode 100644 index 5eb36ed..0000000 Binary files a/software/pg14/pg14.z06 and /dev/null differ diff --git a/software/pg14/pg14.z07 b/software/pg14/pg14.z07 deleted file mode 100644 index b1afbc8..0000000 Binary files a/software/pg14/pg14.z07 and /dev/null differ diff --git a/software/pg14/pg14.zip b/software/pg14/pg14.zip deleted file mode 100644 index 06ac07d..0000000 Binary files a/software/pg14/pg14.zip and /dev/null differ diff --git a/timescaledb/__init__.py b/timescaledb/__init__.py index e69de29..398a9b5 100644 --- a/timescaledb/__init__.py +++ b/timescaledb/__init__.py @@ -0,0 +1,3 @@ +from .router import router +from .database import * +from .timescaledb_info import * \ No newline at end of file diff --git a/timescaledb/database.py b/timescaledb/database.py index 8778536..b7bce40 100644 --- a/timescaledb/database.py +++ b/timescaledb/database.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import AsyncGenerator +from typing import AsyncGenerator, Dict, Optional import psycopg_pool from psycopg.rows import dict_row import timescaledb.timescaledb_info as timescaledb_info @@ -9,12 +9,15 @@ import timescaledb.timescaledb_info as timescaledb_info logger = logging.getLogger(__name__) class Database: - def __init__(self): + def __init__(self, db_name=None): self.pool = None + self.db_name = db_name - def init_pool(self): + def init_pool(self, db_name=None): """Initialize the connection pool.""" - conn_string = timescaledb_info.get_pgconn_string() + # Use provided db_name, or the one from constructor, or default from config + target_db_name = db_name or self.db_name + conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name) try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -23,7 +26,7 @@ class Database: open=False, # Don't open immediately, wait for startup kwargs={"row_factory": dict_row} # Return rows as dictionaries ) - logger.info("TimescaleDB connection pool initialized.") + logger.info(f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}") except Exception as e: logger.error(f"Failed to initialize TimescaleDB connection pool: {e}") raise @@ -47,9 +50,53 @@ class Database: async with self.pool.connection() as conn: yield conn +# 默认数据库实例 db = Database() +# 缓存不同数据库的实例 - 避免重复创建连接池 +_database_instances: Dict[str, Database] = {} + +def create_database_instance(db_name): + """Create a new Database instance for a specific database.""" + return Database(db_name=db_name) + +async def get_database_instance(db_name: Optional[str] = None) -> Database: + """Get or create a database instance for the specified database name.""" + if not db_name: + return db # 返回默认数据库实例 + + if db_name not in _database_instances: + # 创建新的数据库实例 + instance = create_database_instance(db_name) + instance.init_pool() + await instance.open() + _database_instances[db_name] = instance + logger.info(f"Created new database instance for: {db_name}") + + return _database_instances[db_name] + async def get_db_connection(): """Dependency for FastAPI to get a database connection.""" async with db.get_connection() as conn: yield conn + +async def get_database_connection(db_name: Optional[str] = None): + """ + FastAPI dependency to get database connection with optional database name. + 使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name")) + 或在路由函数中: conn: AsyncConnection = Depends(get_database_connection) + """ + instance = await get_database_instance(db_name) + async with instance.get_connection() as conn: + yield conn + +async def cleanup_database_instances(): + """Clean up all database instances (call this on application shutdown).""" + for db_name, instance in _database_instances.items(): + await instance.close() + logger.info(f"Closed database instance for: {db_name}") + _database_instances.clear() + + # 关闭默认数据库 + await db.close() + logger.info("All database instances cleaned up.") diff --git a/timescaledb/router.py b/timescaledb/router.py index 5c28861..7427b75 100644 --- a/timescaledb/router.py +++ b/timescaledb/router.py @@ -1,40 +1,65 @@ from fastapi import APIRouter, Depends, HTTPException, Query -from typing import List, Any, Dict +from typing import List, Optional from datetime import datetime from psycopg import AsyncConnection -from .database import get_db_connection +from .database import get_database_instance from .schemas.realtime import RealtimeRepository from .schemas.scheme import SchemeRepository from .schemas.scada import ScadaRepository router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"]) + +# 创建支持数据库选择的连接依赖函数 +async def get_database_connection( + db_name: Optional[str] = Query( + None, description="指定要连接的数据库名称,为空时使用默认数据库" + ) +): + """获取数据库连接,支持通过查询参数指定数据库名称""" + instance = await get_database_instance(db_name) + async with instance.get_connection() as conn: + yield conn + + # --- Realtime Endpoints --- + @router.post("/realtime/links/batch", status_code=201) async def insert_realtime_links( - data: List[dict], - conn: AsyncConnection = Depends(get_db_connection) + data: List[dict], conn: AsyncConnection = Depends(get_database_connection) ): await RealtimeRepository.insert_links_batch(conn, data) return {"message": f"Inserted {len(data)} records"} + @router.get("/realtime/links") async def get_realtime_links( start_time: datetime, end_time: datetime, - conn: AsyncConnection = Depends(get_db_connection) + conn: AsyncConnection = Depends(get_database_connection), ): return await RealtimeRepository.get_links_by_time(conn, start_time, end_time) + +@router.delete("/realtime/links") +async def delete_realtime_links( + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_database_connection), +): + await RealtimeRepository.delete_links_by_time(conn, start_time, end_time) + return {"message": "Deleted successfully"} + + @router.patch("/realtime/links/{link_id}/field") async def update_realtime_link_field( link_id: str, time: datetime, field: str, - value: float, # Assuming float for now, could be Any but FastAPI needs type - conn: AsyncConnection = Depends(get_db_connection) + value: float, # Assuming float for now, could be Any but FastAPI needs type + conn: AsyncConnection = Depends(get_database_connection), ): try: await RealtimeRepository.update_link_field(conn, time, link_id, field, value) @@ -42,64 +67,214 @@ async def update_realtime_link_field( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + @router.post("/realtime/nodes/batch", status_code=201) async def insert_realtime_nodes( - data: List[dict], - conn: AsyncConnection = Depends(get_db_connection) + data: List[dict], conn: AsyncConnection = Depends(get_database_connection) ): await RealtimeRepository.insert_nodes_batch(conn, data) return {"message": f"Inserted {len(data)} records"} + @router.get("/realtime/nodes") async def get_realtime_nodes( start_time: datetime, end_time: datetime, - conn: AsyncConnection = Depends(get_db_connection) + conn: AsyncConnection = Depends(get_database_connection), ): return await RealtimeRepository.get_nodes_by_time(conn, start_time, end_time) + +@router.delete("/realtime/nodes") +async def delete_realtime_nodes( + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_database_connection), +): + await RealtimeRepository.delete_nodes_by_time(conn, start_time, end_time) + return {"message": "Deleted successfully"} + + # --- Scheme Endpoints --- + @router.post("/scheme/links/batch", status_code=201) async def insert_scheme_links( - data: List[dict], - conn: AsyncConnection = Depends(get_db_connection) + data: List[dict], conn: AsyncConnection = Depends(get_database_connection) ): await SchemeRepository.insert_links_batch(conn, data) return {"message": f"Inserted {len(data)} records"} + @router.get("/scheme/links") async def get_scheme_links( scheme: str, start_time: datetime, end_time: datetime, - conn: AsyncConnection = Depends(get_db_connection) + conn: AsyncConnection = Depends(get_database_connection), ): - return await SchemeRepository.get_links_by_scheme_and_time(conn, scheme, start_time, end_time) + return await SchemeRepository.get_links_by_scheme_and_time( + conn, scheme, start_time, end_time + ) + + +@router.get("/scheme/links/{link_id}/field") +async def get_scheme_link_field( + scheme: str, + link_id: str, + time: datetime, + field: str, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + return await SchemeRepository.get_link_field_by_scheme_and_time( + conn, time, scheme, link_id, field + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.patch("/scheme/links/{link_id}/field") +async def update_scheme_link_field( + scheme: str, + link_id: str, + time: datetime, + field: str, + value: float, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + await SchemeRepository.update_link_field( + conn, time, scheme, link_id, field, value + ) + return {"message": "Updated successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/scheme/links") +async def delete_scheme_links( + scheme: str, + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_database_connection), +): + await SchemeRepository.delete_links_by_scheme_and_time( + conn, scheme, start_time, end_time + ) + return {"message": "Deleted successfully"} + @router.post("/scheme/nodes/batch", status_code=201) async def insert_scheme_nodes( - data: List[dict], - conn: AsyncConnection = Depends(get_db_connection) + data: List[dict], conn: AsyncConnection = Depends(get_database_connection) ): await SchemeRepository.insert_nodes_batch(conn, data) return {"message": f"Inserted {len(data)} records"} + +@router.get("/scheme/nodes/{node_id}/field") +async def get_scheme_node_field( + scheme: str, + node_id: str, + time: datetime, + field: str, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + return await SchemeRepository.get_node_field_by_scheme_and_time( + conn, time, scheme, node_id, field + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.patch("/scheme/nodes/{node_id}/field") +async def update_scheme_node_field( + scheme: str, + node_id: str, + time: datetime, + field: str, + value: float, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + await SchemeRepository.update_node_field( + conn, time, scheme, node_id, field, value + ) + return {"message": "Updated successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/scheme/nodes") +async def delete_scheme_nodes( + scheme: str, + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_database_connection), +): + await SchemeRepository.delete_nodes_by_scheme_and_time( + conn, scheme, start_time, end_time + ) + return {"message": "Deleted successfully"} + + # --- SCADA Endpoints --- + @router.post("/scada/batch", status_code=201) async def insert_scada_data( - data: List[dict], - conn: AsyncConnection = Depends(get_db_connection) + data: List[dict], conn: AsyncConnection = Depends(get_database_connection) ): await ScadaRepository.insert_batch(conn, data) return {"message": f"Inserted {len(data)} records"} + @router.get("/scada") async def get_scada_data( device_id: str, start_time: datetime, end_time: datetime, - conn: AsyncConnection = Depends(get_db_connection) + conn: AsyncConnection = Depends(get_database_connection), ): return await ScadaRepository.get_data_by_time(conn, device_id, start_time, end_time) + + +@router.get("/scada/{device_id}/field") +async def get_scada_field( + device_id: str, + time: datetime, + field: str, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + return await ScadaRepository.get_field(conn, time, device_id, field) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.patch("/scada/{device_id}/field") +async def update_scada_field( + device_id: str, + time: datetime, + field: str, + value: float, + conn: AsyncConnection = Depends(get_database_connection), +): + try: + await ScadaRepository.update_field(conn, time, device_id, field, value) + return {"message": "Updated successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/scada") +async def delete_scada_data( + device_id: str, + start_time: datetime, + end_time: datetime, + conn: AsyncConnection = Depends(get_database_connection), +): + await ScadaRepository.delete_data_by_time(conn, device_id, start_time, end_time) + return {"message": "Deleted successfully"} diff --git a/timescaledb/schemas/realtime.py b/timescaledb/schemas/realtime.py index 9fda2ab..d42c3f1 100644 --- a/timescaledb/schemas/realtime.py +++ b/timescaledb/schemas/realtime.py @@ -2,10 +2,11 @@ from typing import List, Any, Optional from datetime import datetime from psycopg import AsyncConnection, sql + class RealtimeRepository: - + # --- Link Simulation --- - + @staticmethod async def insert_links_batch(conn: AsyncConnection, data: List[dict]): """Batch insert for realtime.link_simulation using COPY for performance.""" @@ -17,51 +18,141 @@ class RealtimeRepository: "COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item['time'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'), - item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity') - )) + await copy.write_row( + ( + item["time"], + item["id"], + item.get("flow"), + item.get("friction"), + item.get("headloss"), + item.get("quality"), + item.get("reaction"), + item.get("setting"), + item.get("status"), + item.get("velocity"), + ) + ) @staticmethod - async def get_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]: + async def get_link_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str + ) -> List[dict]: async with conn.cursor() as cur: await cur.execute( - "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s", - (start_time, end_time) + "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s", + (start_time, end_time, link_id), ) return await cur.fetchall() @staticmethod - async def get_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str) -> Any: + async def get_links_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime + ) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s", + (start_time, end_time), + ) + return await cur.fetchall() + + @staticmethod + async def get_link_field_by_time_range( + conn: AsyncConnection, + start_time: datetime, + end_time: datetime, + link_id: str, + field: str, + ) -> Any: # Validate field name to prevent SQL injection - valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("SELECT {} FROM realtime.link_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: - await cur.execute(query, (time, link_id)) + await cur.execute(query, (start_time, end_time, link_id)) row = await cur.fetchone() return row[field] if row else None @staticmethod - async def update_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str, value: Any): - valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + async def get_links_field_by_time_range( + conn: AsyncConnection, + start_time: datetime, + end_time: datetime, + link_id: str, + field: str, + ) -> Any: + # Validate field name to prevent SQL injection + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (start_time, end_time)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_link_field( + conn: AsyncConnection, + time: datetime, + link_id: str, + field: str, + value: Any, + ): + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: await cur.execute(query, (value, time, link_id)) @staticmethod - async def delete_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime): + async def delete_links_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime + ): async with conn.cursor() as cur: await cur.execute( "DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s", - (start_time, end_time) + (start_time, end_time), ) # --- Node Simulation --- @@ -76,39 +167,102 @@ class RealtimeRepository: "COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item['time'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality') - )) + await copy.write_row( + ( + item["time"], + item["id"], + item.get("actual_demand"), + item.get("total_head"), + item.get("pressure"), + item.get("quality"), + ) + ) @staticmethod - async def get_nodes_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]: + async def get_node_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str + ) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s", + (start_time, end_time, node_id), + ) + return await cur.fetchall() + + @staticmethod + async def get_nodes_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime + ) -> List[dict]: async with conn.cursor() as cur: await cur.execute( "SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s", - (start_time, end_time) + (start_time, end_time), ) return await cur.fetchall() - + @staticmethod - async def get_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str) -> Any: + async def get_node_field_by_time_range( + conn: AsyncConnection, + start_time: datetime, + end_time: datetime, + node_id: str, + field: str, + ) -> Any: valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("SELECT {} FROM realtime.node_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: - await cur.execute(query, (time, node_id)) + await cur.execute(query, (start_time, end_time, node_id)) row = await cur.fetchone() return row[field] if row else None @staticmethod - async def update_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str, value: Any): + async def get_nodes_field_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str + ) -> Any: valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (start_time, end_time)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_node_field( + conn: AsyncConnection, + time: datetime, + node_id: str, + field: str, + value: Any, + ): + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: await cur.execute(query, (value, time, node_id)) + + @staticmethod + async def delete_nodes_by_time_range( + conn: AsyncConnection, start_time: datetime, end_time: datetime + ): + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s", + (start_time, end_time), + ) diff --git a/timescaledb/schemas/scada.py b/timescaledb/schemas/scada.py index a4c4202..8e9396e 100644 --- a/timescaledb/schemas/scada.py +++ b/timescaledb/schemas/scada.py @@ -2,10 +2,11 @@ from typing import List, Any from datetime import datetime from psycopg import AsyncConnection, sql + class ScadaRepository: - + @staticmethod - async def insert_batch(conn: AsyncConnection, data: List[dict]): + async def insert_scada_batch(conn: AsyncConnection, data: List[dict]): if not data: return @@ -14,26 +15,64 @@ class ScadaRepository: "COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item['time'], item['device_id'], item.get('monitored_value'), item.get('cleaned_value') - )) + await copy.write_row( + ( + item["time"], + item["device_id"], + item.get("monitored_value"), + item.get("cleaned_value"), + ) + ) @staticmethod - async def get_data_by_time(conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime) -> List[dict]: + async def get_scada_by_id_time( + conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime + ) -> List[dict]: async with conn.cursor() as cur: await cur.execute( "SELECT * FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s", - (device_id, start_time, end_time) + (device_id, start_time, end_time), ) return await cur.fetchall() @staticmethod - async def update_field(conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any): + async def get_scada_field_by_id_time( + conn: AsyncConnection, time: datetime, device_id: str, field: str + ) -> Any: valid_fields = {"monitored_value", "cleaned_value"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM scada.scada_data WHERE time = %s AND device_id = %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (time, device_id)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_scada_field( + conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any + ): + valid_fields = {"monitored_value", "cleaned_value"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: await cur.execute(query, (value, time, device_id)) + + @staticmethod + async def delete_scada_by_id_time( + conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime + ): + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s", + (device_id, start_time, end_time), + ) diff --git a/timescaledb/schemas/scheme.py b/timescaledb/schemas/scheme.py index ef699e8..f1fd5a6 100644 --- a/timescaledb/schemas/scheme.py +++ b/timescaledb/schemas/scheme.py @@ -1,13 +1,15 @@ -from typing import List, Any +from typing import List, Any, Optional from datetime import datetime from psycopg import AsyncConnection, sql + class SchemeRepository: - + # --- Link Simulation --- - + @staticmethod async def insert_links_batch(conn: AsyncConnection, data: List[dict]): + """Batch insert for scheme.link_simulation using COPY for performance.""" if not data: return @@ -16,31 +18,150 @@ class SchemeRepository: "COPY scheme.link_simulation (time, scheme, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item['time'], item['scheme'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'), - item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity') - )) + await copy.write_row( + ( + item["time"], + item["scheme"], + item["id"], + item.get("flow"), + item.get("friction"), + item.get("headloss"), + item.get("quality"), + item.get("reaction"), + item.get("setting"), + item.get("status"), + item.get("velocity"), + ) + ) @staticmethod - async def get_links_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]: + async def get_link_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + link_id: str, + ) -> List[dict]: async with conn.cursor() as cur: await cur.execute( - "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s", - (scheme, start_time, end_time) + "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s", + (scheme, start_time, end_time, link_id), ) return await cur.fetchall() @staticmethod - async def update_link_field(conn: AsyncConnection, time: datetime, scheme: str, link_id: str, field: str, value: Any): - valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"} + async def get_links_by_scheme_and_time_range( + conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime + ) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time), + ) + return await cur.fetchall() + + @staticmethod + async def get_link_field_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + link_id: str, + field: str, + ) -> Any: + # Validate field name to prevent SQL injection + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (scheme, start_time, end_time, link_id)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def get_links_field_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + field: str, + ) -> Any: + # Validate field name to prevent SQL injection + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (scheme, start_time, end_time)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_link_field( + conn: AsyncConnection, + time: datetime, + scheme: str, + link_id: str, + field: str, + value: Any, + ): + valid_fields = { + "flow", + "friction", + "headloss", + "quality", + "reaction", + "setting", + "status", + "velocity", + } + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: await cur.execute(query, (value, time, scheme, link_id)) + @staticmethod + async def delete_links_by_scheme_and_time_range( + conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime + ): + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time), + ) + # --- Node Simulation --- @staticmethod @@ -53,26 +174,115 @@ class SchemeRepository: "COPY scheme.node_simulation (time, scheme, id, actual_demand, total_head, pressure, quality) FROM STDIN" ) as copy: for item in data: - await copy.write_row(( - item['time'], item['scheme'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality') - )) + await copy.write_row( + ( + item["time"], + item["scheme"], + item["id"], + item.get("actual_demand"), + item.get("total_head"), + item.get("pressure"), + item.get("quality"), + ) + ) @staticmethod - async def get_nodes_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]: + async def get_node_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + node_id: str, + ) -> List[dict]: async with conn.cursor() as cur: await cur.execute( - "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s", - (scheme, start_time, end_time) + "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s", + (scheme, start_time, end_time, node_id), ) return await cur.fetchall() @staticmethod - async def update_node_field(conn: AsyncConnection, time: datetime, scheme: str, node_id: str, field: str, value: Any): + async def get_nodes_by_scheme_and_time_range( + conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime + ) -> List[dict]: + async with conn.cursor() as cur: + await cur.execute( + "SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time), + ) + return await cur.fetchall() + + @staticmethod + async def get_node_field_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + node_id: str, + field: str, + ) -> Any: + # Validate field name to prevent SQL injection valid_fields = {"actual_demand", "total_head", "pressure", "quality"} if field not in valid_fields: raise ValueError(f"Invalid field: {field}") - query = sql.SQL("UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field)) - + query = sql.SQL( + "SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (scheme, start_time, end_time, node_id)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def get_nodes_field_by_scheme_and_time_range( + conn: AsyncConnection, + scheme: str, + start_time: datetime, + end_time: datetime, + field: str, + ) -> Any: + # Validate field name to prevent SQL injection + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s" + ).format(sql.Identifier(field)) + + async with conn.cursor() as cur: + await cur.execute(query, (scheme, start_time, end_time)) + row = await cur.fetchone() + return row[field] if row else None + + @staticmethod + async def update_node_field( + conn: AsyncConnection, + time: datetime, + scheme: str, + node_id: str, + field: str, + value: Any, + ): + valid_fields = {"actual_demand", "total_head", "pressure", "quality"} + if field not in valid_fields: + raise ValueError(f"Invalid field: {field}") + + query = sql.SQL( + "UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s" + ).format(sql.Identifier(field)) + async with conn.cursor() as cur: await cur.execute(query, (value, time, scheme, node_id)) + + @staticmethod + async def delete_nodes_by_scheme_and_time_range( + conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime + ): + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s", + (scheme, start_time, end_time), + )