Files
TJWaterServerBinary/app/infra/db/timescaledb/schemas/scheme.py

761 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
import globals
# 定义UTC+8时区
UTC_8 = timezone(timedelta(hours=8))
class SchemeRepository:
    """Static data-access helpers for the scheme.link_simulation and
    scheme.node_simulation TimescaleDB tables (async and sync psycopg)."""

    # --- Link Simulation ---
@staticmethod
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
"""Batch insert for scheme.link_simulation using DELETE then COPY for performance."""
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
def insert_links_batch_sync(conn: Connection, data: List[dict]):
"""Batch insert for scheme.link_simulation using DELETE then COPY for performance (sync version)."""
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
cur.execute(
"DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_link_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def get_links_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> List[Dict[str, Any]]:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, link_id)
)
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_links_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_link_field(
conn: AsyncConnection,
time: datetime,
scheme_type: str,
scheme_name: str,
link_id: str,
field: str,
value: Any,
):
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))
@staticmethod
async def delete_links_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
# --- Node Simulation ---
@staticmethod
async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
cur.execute(
"DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_node_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def get_nodes_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
field: str,
) -> List[Dict[str, Any]]:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, node_id)
)
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_nodes_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_node_field(
conn: AsyncConnection,
time: datetime,
scheme_type: str,
scheme_name: str,
node_id: str,
field: str,
value: Any,
):
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))
@staticmethod
async def delete_nodes_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
    # --- Composite queries ---
@staticmethod
async def store_scheme_simulation_result(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
await SchemeRepository.insert_nodes_batch(conn, node_data)
if link_data:
await SchemeRepository.insert_links_batch(conn, link_data)
@staticmethod
def store_scheme_simulation_result_sync(
conn: Connection,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB (sync version).
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
SchemeRepository.insert_nodes_batch_sync(conn, node_data)
if link_data:
SchemeRepository.insert_links_batch_sync(conn, link_data)
@staticmethod
async def query_all_record_by_scheme_time_property(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
) -> list:
"""
Query all records by scheme, time and property from TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
query_time: Time to query (ISO format string)
type: Type of data ("node" or "link")
property: Property/field to query
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
data = await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, property
)
elif type.lower() == "link":
data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, property
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
# Format the results
# Format the results
result = []
for id, items in data.items():
for item in items:
result.append({"ID": id, "value": item["value"]})
return result
@staticmethod
async def query_scheme_simulation_result_by_id_time(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
id: str,
type: str,
query_time: str,
) -> list[dict]:
"""
Query scheme simulation results by id and time from TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
id: The id of the node or link
type: Type of data ("node" or "link")
query_time: Time to query (ISO format string)
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
return await SchemeRepository.get_node_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, id
)
elif type.lower() == "link":
return await SchemeRepository.get_link_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, id
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")