删除旧文件;更新数据库查询方法

This commit is contained in:
JIANG
2025-12-04 17:32:54 +08:00
parent d04d29cfb3
commit a855fa1f42
26 changed files with 904 additions and 94 deletions

11
main.py
View File

@@ -3123,6 +3123,8 @@ async def fastapi_query_cleaning_scada_data_by_device_id_and_time_range(
return influxdb_api.query_cleaning_SCADA_data_by_device_ID_and_timerange(
query_ids_list=query_ids, start_time=starttime, end_time=endtime
)
# 查询到的SCADA模拟数据从 realtime_simulation bucket 中查找)
@app.get("/querysimulationscadadatabydeviceidandtimerange/")
async def fastapi_query_simulation_scada_data_by_device_id_and_time_range(
@@ -3447,7 +3449,7 @@ async def fastapi_run_simulation_manually_by_date(
item["name"], region_result
)
)
(
globals.source_outflow_region_patterns,
globals.realtime_region_pipe_flow_and_demand_patterns,
@@ -4212,9 +4214,10 @@ if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
# url='http://127.0.0.1:8000/valve_close_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&valve_IDs=GSD2307192058577780A3287D78&valve_IDs=GSD2307192058572E953B707226(S2)&duration=1800'
# url='http://127.0.0.1:8000/burst_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&burst_ID=ZBBGXSZW000001&duration=1800'
url = "http://192.168.1.36:8000/queryallschemeallrecords/?schemename=Fangan0817114448&querydate=2025-08-13&schemetype=burst_Analysis"
# url = "http://192.168.1.36:8000/queryallschemeallrecords/?schemename=Fangan0817114448&querydate=2025-08-13&schemetype=burst_Analysis"
# response = Request.get(url)
import requests
# import requests
response = requests.get(url)
# response = requests.get(url)
print(get_all_scada_info("szh"))

104
postgresql/database.py Normal file
View File

@@ -0,0 +1,104 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import postgresql_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
def __init__(self, db_name=None):
self.pool = None
self.db_name = db_name
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
target_db_name = db_name or self.db_name
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
min_size=1,
max_size=20,
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")
raise
async def open(self):
if self.pool:
await self.pool.open()
async def close(self):
"""Close the connection pool."""
if self.pool:
await self.pool.close()
logger.info("PostgreSQL connection pool closed.")
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:
"""Get a connection from the pool."""
if not self.pool:
raise Exception("Database pool is not initialized.")
async with self.pool.connection() as conn:
yield conn
# 默认数据库实例
db = Database()
# 缓存不同数据库的实例 - 避免重复创建连接池
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
"""Create a new Database instance for a specific database."""
return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
"""Get or create a database instance for the specified database name."""
if not db_name:
return db # 返回默认数据库实例
if db_name not in _database_instances:
# 创建新的数据库实例
instance = create_database_instance(db_name)
instance.init_pool()
await instance.open()
_database_instances[db_name] = instance
logger.info(f"Created new database instance for: {db_name}")
return _database_instances[db_name]
async def get_db_connection():
"""Dependency for FastAPI to get a database connection."""
async with db.get_connection() as conn:
yield conn
async def get_database_connection(db_name: Optional[str] = None):
"""
FastAPI dependency to get database connection with optional database name.
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
async def cleanup_database_instances():
"""Clean up all database instances (call this on application shutdown)."""
for db_name, instance in _database_instances.items():
await instance.close()
logger.info(f"Closed database instance for: {db_name}")
_database_instances.clear()
# 关闭默认数据库
await db.close()
logger.info("All database instances cleaned up.")

40
postgresql/router.py Normal file
View File

@@ -0,0 +1,40 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import query_pg_scada_info
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
@router.get("/scada-info")
async def get_scada_info_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询SCADA信息
"""
try:
# 使用连接查询SCADA信息
async with conn.cursor() as cur:
await cur.execute("SELECT * FROM scada_info")
scada_data = await cur.fetchall()
return {"success": True, "data": scada_data, "count": len(scada_data)}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
)

35
postgresql/scada_info.py Normal file
View File

@@ -0,0 +1,35 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class ScadaRepository:
@staticmethod
async def get_scadas_info(conn: AsyncConnection) -> List[dict]:
"""
查询pg数据库中,scada_info 的所有记录
:param conn: 异步数据库连接
:return: 包含所有记录的列表,每条记录为一个字典
"""
async with conn.cursor() as cur:
await cur.execute(
"""
SELECT id, type, transmission_mode, transmission_frequency, reliability
FROM public.scada_info
"""
)
records = await cur.fetchall()
# 将查询结果转换为字典列表
records_list = []
for record in records:
record_dict = {
"id": record[0],
"type": record[1],
"transmission_mode": record[2],
"transmission_frequency": record[3],
"reliability": record[4],
}
records_list.append(record_dict)
return records_list

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,3 @@
from .router import router
from .database import *
from .timescaledb_info import *

View File

@@ -1,6 +1,6 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import timescaledb.timescaledb_info as timescaledb_info
@@ -9,12 +9,15 @@ import timescaledb.timescaledb_info as timescaledb_info
logger = logging.getLogger(__name__)
class Database:
def __init__(self):
def __init__(self, db_name=None):
self.pool = None
self.db_name = db_name
def init_pool(self):
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
conn_string = timescaledb_info.get_pgconn_string()
# Use provided db_name, or the one from constructor, or default from config
target_db_name = db_name or self.db_name
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -23,7 +26,7 @@ class Database:
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row} # Return rows as dictionaries
)
logger.info("TimescaleDB connection pool initialized.")
logger.info(f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
raise
@@ -47,9 +50,53 @@ class Database:
async with self.pool.connection() as conn:
yield conn
# 默认数据库实例
db = Database()
# 缓存不同数据库的实例 - 避免重复创建连接池
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
"""Create a new Database instance for a specific database."""
return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
"""Get or create a database instance for the specified database name."""
if not db_name:
return db # 返回默认数据库实例
if db_name not in _database_instances:
# 创建新的数据库实例
instance = create_database_instance(db_name)
instance.init_pool()
await instance.open()
_database_instances[db_name] = instance
logger.info(f"Created new database instance for: {db_name}")
return _database_instances[db_name]
async def get_db_connection():
"""Dependency for FastAPI to get a database connection."""
async with db.get_connection() as conn:
yield conn
async def get_database_connection(db_name: Optional[str] = None):
"""
FastAPI dependency to get database connection with optional database name.
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
async def cleanup_database_instances():
"""Clean up all database instances (call this on application shutdown)."""
for db_name, instance in _database_instances.items():
await instance.close()
logger.info(f"Closed database instance for: {db_name}")
_database_instances.clear()
# 关闭默认数据库
await db.close()
logger.info("All database instances cleaned up.")

View File

@@ -1,40 +1,65 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Any, Dict
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_db_connection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
# --- Realtime Endpoints ---
@router.post("/realtime/links/batch", status_code=201)
async def insert_realtime_links(
data: List[dict],
conn: AsyncConnection = Depends(get_db_connection)
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await RealtimeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/realtime/links")
async def get_realtime_links(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_db_connection)
conn: AsyncConnection = Depends(get_database_connection),
):
return await RealtimeRepository.get_links_by_time(conn, start_time, end_time)
@router.delete("/realtime/links")
async def delete_realtime_links(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await RealtimeRepository.delete_links_by_time(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.patch("/realtime/links/{link_id}/field")
async def update_realtime_link_field(
link_id: str,
time: datetime,
field: str,
value: float, # Assuming float for now, could be Any but FastAPI needs type
conn: AsyncConnection = Depends(get_db_connection)
value: float, # Assuming float for now, could be Any but FastAPI needs type
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
@@ -42,64 +67,214 @@ async def update_realtime_link_field(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/realtime/nodes/batch", status_code=201)
async def insert_realtime_nodes(
data: List[dict],
conn: AsyncConnection = Depends(get_db_connection)
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await RealtimeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/realtime/nodes")
async def get_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_db_connection)
conn: AsyncConnection = Depends(get_database_connection),
):
return await RealtimeRepository.get_nodes_by_time(conn, start_time, end_time)
@router.delete("/realtime/nodes")
async def delete_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await RealtimeRepository.delete_nodes_by_time(conn, start_time, end_time)
return {"message": "Deleted successfully"}
# --- Scheme Endpoints ---
@router.post("/scheme/links/batch", status_code=201)
async def insert_scheme_links(
data: List[dict],
conn: AsyncConnection = Depends(get_db_connection)
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/links")
async def get_scheme_links(
scheme: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_db_connection)
conn: AsyncConnection = Depends(get_database_connection),
):
return await SchemeRepository.get_links_by_scheme_and_time(conn, scheme, start_time, end_time)
return await SchemeRepository.get_links_by_scheme_and_time(
conn, scheme, start_time, end_time
)
@router.get("/scheme/links/{link_id}/field")
async def get_scheme_link_field(
scheme: str,
link_id: str,
time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_link_field_by_scheme_and_time(
conn, time, scheme, link_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/links/{link_id}/field")
async def update_scheme_link_field(
scheme: str,
link_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_link_field(
conn, time, scheme, link_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/links")
async def delete_scheme_links(
scheme: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_links_by_scheme_and_time(
conn, scheme, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/nodes/batch", status_code=201)
async def insert_scheme_nodes(
data: List[dict],
conn: AsyncConnection = Depends(get_db_connection)
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/nodes/{node_id}/field")
async def get_scheme_node_field(
scheme: str,
node_id: str,
time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_node_field_by_scheme_and_time(
conn, time, scheme, node_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/nodes/{node_id}/field")
async def update_scheme_node_field(
scheme: str,
node_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_node_field(
conn, time, scheme, node_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/nodes")
async def delete_scheme_nodes(
scheme: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_nodes_by_scheme_and_time(
conn, scheme, start_time, end_time
)
return {"message": "Deleted successfully"}
# --- SCADA Endpoints ---
@router.post("/scada/batch", status_code=201)
async def insert_scada_data(
data: List[dict],
conn: AsyncConnection = Depends(get_db_connection)
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await ScadaRepository.insert_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scada")
async def get_scada_data(
device_id: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_db_connection)
conn: AsyncConnection = Depends(get_database_connection),
):
return await ScadaRepository.get_data_by_time(conn, device_id, start_time, end_time)
@router.get("/scada/{device_id}/field")
async def get_scada_field(
device_id: str,
time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await ScadaRepository.get_field(conn, time, device_id, field)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scada/{device_id}/field")
async def update_scada_field(
device_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await ScadaRepository.update_field(conn, time, device_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scada")
async def delete_scada_data(
device_id: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await ScadaRepository.delete_data_by_time(conn, device_id, start_time, end_time)
return {"message": "Deleted successfully"}

View File

@@ -2,10 +2,11 @@ from typing import List, Any, Optional
from datetime import datetime
from psycopg import AsyncConnection, sql
class RealtimeRepository:
# --- Link Simulation ---
@staticmethod
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
"""Batch insert for realtime.link_simulation using COPY for performance."""
@@ -17,51 +18,141 @@ class RealtimeRepository:
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row((
item['time'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'),
item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity')
))
await copy.write_row(
(
item["time"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]:
async def get_link_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time)
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
(start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def get_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str) -> Any:
async def get_links_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"}
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("SELECT {} FROM realtime.link_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (time, link_id))
await cur.execute(query, (start_time, end_time, link_id))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_link_field(conn: AsyncConnection, time: datetime, link_id: str, field: str, value: Any):
valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"}
async def get_links_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_link_field(
conn: AsyncConnection,
time: datetime,
link_id: str,
field: str,
value: Any,
):
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, link_id))
@staticmethod
async def delete_links_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime):
async def delete_links_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time)
(start_time, end_time),
)
# --- Node Simulation ---
@@ -76,39 +167,102 @@ class RealtimeRepository:
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row((
item['time'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality')
))
await copy.write_row(
(
item["time"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_nodes_by_time(conn: AsyncConnection, start_time: datetime, end_time: datetime) -> List[dict]:
async def get_node_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
(start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def get_nodes_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time)
(start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str) -> Any:
async def get_node_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
node_id: str,
field: str,
) -> Any:
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("SELECT {} FROM realtime.node_simulation WHERE time = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (time, node_id))
await cur.execute(query, (start_time, end_time, node_id))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_node_field(conn: AsyncConnection, time: datetime, node_id: str, field: str, value: Any):
async def get_nodes_field_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
) -> Any:
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_node_field(
conn: AsyncConnection,
time: datetime,
node_id: str,
field: str,
value: Any,
):
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, node_id))
@staticmethod
async def delete_nodes_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)

View File

@@ -2,10 +2,11 @@ from typing import List, Any
from datetime import datetime
from psycopg import AsyncConnection, sql
class ScadaRepository:
@staticmethod
async def insert_batch(conn: AsyncConnection, data: List[dict]):
async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
if not data:
return
@@ -14,26 +15,64 @@ class ScadaRepository:
"COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN"
) as copy:
for item in data:
await copy.write_row((
item['time'], item['device_id'], item.get('monitored_value'), item.get('cleaned_value')
))
await copy.write_row(
(
item["time"],
item["device_id"],
item.get("monitored_value"),
item.get("cleaned_value"),
)
)
@staticmethod
async def get_data_by_time(conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime) -> List[dict]:
async def get_scada_by_id_time(
conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
(device_id, start_time, end_time)
(device_id, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def update_field(conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any):
async def get_scada_field_by_id_time(
conn: AsyncConnection, time: datetime, device_id: str, field: str
) -> Any:
valid_fields = {"monitored_value", "cleaned_value"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM scada.scada_data WHERE time = %s AND device_id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (time, device_id))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_scada_field(
conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
):
valid_fields = {"monitored_value", "cleaned_value"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, device_id))
@staticmethod
async def delete_scada_by_id_time(
conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
(device_id, start_time, end_time),
)

View File

@@ -1,13 +1,15 @@
from typing import List, Any
from typing import List, Any, Optional
from datetime import datetime
from psycopg import AsyncConnection, sql
class SchemeRepository:
# --- Link Simulation ---
@staticmethod
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
"""Batch insert for scheme.link_simulation using COPY for performance."""
if not data:
return
@@ -16,31 +18,150 @@ class SchemeRepository:
"COPY scheme.link_simulation (time, scheme, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row((
item['time'], item['scheme'], item['id'], item.get('flow'), item.get('friction'), item.get('headloss'),
item.get('quality'), item.get('reaction'), item.get('setting'), item.get('status'), item.get('velocity')
))
await copy.write_row(
(
item["time"],
item["scheme"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_links_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]:
async def get_link_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
link_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time)
"SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
(scheme, start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def update_link_field(conn: AsyncConnection, time: datetime, scheme: str, link_id: str, field: str, value: Any):
valid_fields = {"flow", "friction", "headloss", "quality", "reaction", "setting", "status", "velocity"}
async def get_links_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time, link_id))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def get_links_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_link_field(
conn: AsyncConnection,
time: datetime,
scheme: str,
link_id: str,
field: str,
value: Any,
):
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme, link_id))
@staticmethod
async def delete_links_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
)
# --- Node Simulation ---
@staticmethod
@@ -53,26 +174,115 @@ class SchemeRepository:
"COPY scheme.node_simulation (time, scheme, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row((
item['time'], item['scheme'], item['id'], item.get('actual_demand'), item.get('total_head'), item.get('pressure'), item.get('quality')
))
await copy.write_row(
(
item["time"],
item["scheme"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_nodes_by_scheme_and_time(conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime) -> List[dict]:
async def get_node_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
node_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time)
"SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s",
(scheme, start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def update_node_field(conn: AsyncConnection, time: datetime, scheme: str, node_id: str, field: str, value: Any):
async def get_nodes_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
node_id: str,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL("UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s").format(sql.Identifier(field))
query = sql.SQL(
"SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time, node_id))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def get_nodes_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> Any:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT {} FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme, start_time, end_time))
row = await cur.fetchone()
return row[field] if row else None
@staticmethod
async def update_node_field(
conn: AsyncConnection,
time: datetime,
scheme: str,
node_id: str,
field: str,
value: Any,
):
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme, node_id))
@staticmethod
async def delete_nodes_by_scheme_and_time_range(
conn: AsyncConnection, scheme: str, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE scheme = %s AND time >= %s AND time <= %s",
(scheme, start_time, end_time),
)