重构现代化 FastAPI 后端项目框架

This commit is contained in:
2026-01-21 16:50:57 +08:00
parent 9e06e68a15
commit c56f2fd1db
352 changed files with 176 additions and 70 deletions
View File
View File
View File
View File
View File
File diff suppressed because it is too large Load Diff
+6
View File
@@ -0,0 +1,6 @@
# influxdb数据库连接信息
url = "http://127.0.0.1:8086" # 替换为你的InfluxDB实例地址
token = "kMPX2V5HsbzPpUT2B9HPBu1sTG1Emf-lPlT2UjxYnGAuocpXq_f_0lK4HHs-TbbKyjsZpICkMsyXG_V2D7P7yQ==" # 替换为你的InfluxDB Token
# _ENCODED_TOKEN = "eEdETTVSWnFSSkF1ekFHUy1vdFhVZEMyTkZkWTc1cUpBalJMcUFCNHA1V2NJSUFsSVVwT3BUOF95QTE2QU9IbUpXZXJ3UV8wOGd3Yjg0c3k0MmpuWlE9PQ=="
# token = base64.b64decode(_ENCODED_TOKEN).decode("utf-8")
org = "TJWATERORG" # 替换为你的Organization名称
+33
View File
@@ -0,0 +1,33 @@
from influxdb_client import InfluxDBClient, Point, WriteOptions
from influxdb_client.client.query_api import QueryApi
import influxdb_info
# 配置 InfluxDB 连接
url = influxdb_info.url
token = influxdb_info.token
org = influxdb_info.org
bucket = "SCADA_data"
# 创建 InfluxDB 客户端
client = InfluxDBClient(url=url, token=token, org=org)
# 创建查询 API 对象
query_api = client.query_api()
# 构建查询语句
query = f'''
from(bucket: "{bucket}")
|> range(start: -1h)
'''
# 执行查询
result = query_api.query(query)
print(result)
# 处理查询结果
for table in result:
for record in table.records:
print(f"Time: {record.get_time()}, Value: {record.get_value()}, Measurement: {record.get_measurement()}, Field: {record.get_field()}")
# 关闭客户端连接
client.close()
+1
View File
@@ -0,0 +1 @@
from .router import router
+108
View File
@@ -0,0 +1,108 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.native.api.postgresql_info as postgresql_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
def __init__(self, db_name=None):
self.pool = None
self.db_name = db_name
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = postgresql_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
min_size=5,
max_size=20,
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(f"PostgreSQL connection pool initialized for database: default")
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")
raise
async def open(self):
if self.pool:
await self.pool.open()
async def close(self):
"""Close the connection pool."""
if self.pool:
await self.pool.close()
logger.info("PostgreSQL connection pool closed.")
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:
"""Get a connection from the pool."""
if not self.pool:
raise Exception("Database pool is not initialized.")
async with self.pool.connection() as conn:
yield conn
# 默认数据库实例
db = Database()
# 缓存不同数据库的实例 - 避免重复创建连接池
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
"""Create a new Database instance for a specific database."""
return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
"""Get or create a database instance for the specified database name."""
if not db_name:
return db # 返回默认数据库实例
if db_name not in _database_instances:
# 创建新的数据库实例
instance = create_database_instance(db_name)
instance.init_pool()
await instance.open()
_database_instances[db_name] = instance
logger.info(f"Created new database instance for: {db_name}")
return _database_instances[db_name]
async def get_db_connection():
"""Dependency for FastAPI to get a database connection."""
async with db.get_connection() as conn:
yield conn
async def get_database_connection(db_name: Optional[str] = None):
"""
FastAPI dependency to get database connection with optional database name.
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
async def cleanup_database_instances():
"""Clean up all database instances (call this on application shutdown)."""
for db_name, instance in _database_instances.items():
await instance.close()
logger.info(f"Closed database instance for: {db_name}")
_database_instances.clear()
# 关闭默认数据库
await db.close()
logger.info("All database instances cleaned up.")
@@ -0,0 +1,83 @@
import time
from typing import List, Optional
from fastapi.logger import logger
import postgresql_info
import psycopg
class InternalQueries:
@staticmethod
def get_links_by_property(
fields: Optional[List[str]] = None,
property_conditions: Optional[dict] = None,
db_name: str = None,
max_retries: int = 3,
) -> List[dict]:
"""
查询pg数据库中,pipes 的指定字段记录或根据属性筛选
:param fields: 要查询的字段列表,如 ["id", "diameter", "status"],默认查询所有字段
:param property: 可选的筛选条件字典,如 {"status": "Open"} 或 {"diameter": 300}
:param db_name: 数据库名称
:param max_retries: 最大重试次数
:return: 包含所有记录的列表,每条记录为一个字典
"""
# 如果未指定字段,查询所有字段
if not fields:
fields = [
"id",
"node1",
"node2",
"length",
"diameter",
"roughness",
"minor_loss",
"status",
]
for attempt in range(max_retries):
try:
conn_string = (
postgresql_info.get_pgconn_string(db_name=db_name)
if db_name
else postgresql_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
with conn.cursor() as cur:
# 构建SELECT子句
select_fields = ", ".join(fields)
base_query = f"""
SELECT {select_fields}
FROM public.pipes
"""
# 如果提供了筛选条件,构建WHERE子句
if property_conditions:
conditions = []
params = []
for key, value in property_conditions.items():
conditions.append(f"{key} = %s")
params.append(value)
query = base_query + " WHERE " + " AND ".join(conditions)
cur.execute(query, params)
else:
cur.execute(base_query)
records = cur.fetchall()
# 将查询结果转换为字典列表
pipes = []
for record in records:
pipe_dict = {}
for idx, field in enumerate(fields):
pipe_dict[field] = record[idx]
pipes.append(pipe_dict)
return pipes
break # 成功
except Exception as e:
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1)
else:
raise
+90
View File
@@ -0,0 +1,90 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional
from psycopg import AsyncConnection
from .database import get_database_instance
from .scada_info import ScadaRepository
from .scheme import SchemeRepository
router = APIRouter(prefix="/postgresql", tags=["postgresql"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
@router.get("/scada-info")
async def get_scada_info_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有SCADA信息
"""
try:
# 使用ScadaRepository查询SCADA信息
scada_data = await ScadaRepository.get_scadas(conn)
return {"success": True, "data": scada_data, "count": len(scada_data)}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}"
)
@router.get("/scheme-list")
async def get_scheme_list_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有方案信息
"""
try:
# 使用SchemeRepository查询方案信息
scheme_data = await SchemeRepository.get_schemes(conn)
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
except Exception as e:
raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}")
@router.get("/burst-locate-result")
async def get_burst_locate_result_with_connection(
conn: AsyncConnection = Depends(get_database_connection),
):
"""
使用连接池查询所有爆管定位结果
"""
try:
# 使用SchemeRepository查询爆管定位结果
burst_data = await SchemeRepository.get_burst_locate_results(conn)
return {"success": True, "data": burst_data, "count": len(burst_data)}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}"
)
@router.get("/burst-locate-result/{burst_incident}")
async def get_burst_locate_result_by_incident(
burst_incident: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""
根据 burst_incident 查询爆管定位结果
"""
try:
# 使用SchemeRepository查询爆管定位结果
return await SchemeRepository.get_burst_locate_result_by_incident(
conn, burst_incident
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}",
)
+36
View File
@@ -0,0 +1,36 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class ScadaRepository:
@staticmethod
async def get_scadas(conn: AsyncConnection) -> List[dict]:
"""
查询pg数据库中,scada_info 的所有记录
:param conn: 异步数据库连接
:return: 包含所有记录的列表,每条记录为一个字典
"""
async with conn.cursor() as cur:
await cur.execute(
"""
SELECT id, type, associated_element_id, transmission_mode, transmission_frequency, reliability
FROM public.scada_info
"""
)
records = await cur.fetchall()
# 将查询结果转换为字典列表(假设 record 是字典)
scada_infos = []
for record in records:
scada_infos.append(
{
"id": record["id"], # 使用字典键
"type": record["type"],
"associated_element_id": record["associated_element_id"],
"transmission_mode": record["transmission_mode"],
"transmission_frequency": record["transmission_frequency"],
"reliability": record["reliability"],
}
)
return scada_infos
+104
View File
@@ -0,0 +1,104 @@
from typing import List, Optional, Any
from psycopg import AsyncConnection
class SchemeRepository:
@staticmethod
async def get_schemes(conn: AsyncConnection) -> List[dict]:
"""
查询pg数据库中, scheme_list 的所有记录
:param conn: 异步数据库连接
:return: 包含所有记录的列表, 每条记录为一个字典
"""
async with conn.cursor() as cur:
await cur.execute(
"""
SELECT scheme_id, scheme_name, scheme_type, username, create_time, scheme_start_time, scheme_detail
FROM public.scheme_list
"""
)
records = await cur.fetchall()
scheme_list = []
for record in records:
scheme_list.append(
{
"scheme_id": record["scheme_id"],
"scheme_name": record["scheme_name"],
"scheme_type": record["scheme_type"],
"username": record["username"],
"create_time": record["create_time"],
"scheme_start_time": record["scheme_start_time"],
"scheme_detail": record["scheme_detail"],
}
)
return scheme_list
@staticmethod
async def get_burst_locate_results(conn: AsyncConnection) -> List[dict]:
"""
查询pg数据库中, burst_locate_result 的所有记录
:param conn: 异步数据库连接
:return: 包含所有记录的列表, 每条记录为一个字典
"""
async with conn.cursor() as cur:
await cur.execute(
"""
SELECT id, type, burst_incident, leakage, detect_time, locate_result
FROM public.burst_locate_result
"""
)
records = await cur.fetchall()
results = []
for record in records:
results.append(
{
"id": record["id"],
"type": record["type"],
"burst_incident": record["burst_incident"],
"leakage": record["leakage"],
"detect_time": record["detect_time"],
"locate_result": record["locate_result"],
}
)
return results
@staticmethod
async def get_burst_locate_result_by_incident(
conn: AsyncConnection, burst_incident: str
) -> List[dict]:
"""
根据 burst_incident 查询爆管定位结果
:param conn: 异步数据库连接
:param burst_incident: 爆管事件标识
:return: 包含匹配记录的列表
"""
async with conn.cursor() as cur:
await cur.execute(
"""
SELECT id, type, burst_incident, leakage, detect_time, locate_result
FROM public.burst_locate_result
WHERE burst_incident = %s
""",
(burst_incident,),
)
records = await cur.fetchall()
results = []
for record in records:
results.append(
{
"id": record["id"],
"type": record["type"],
"burst_incident": record["burst_incident"],
"leakage": record["leakage"],
"detect_time": record["detect_time"],
"locate_result": record["locate_result"],
}
)
return results
+4
View File
@@ -0,0 +1,4 @@
from .router import router
from .database import *
from .timescaledb_info import *
from .composite_queries import CompositeQueries
@@ -0,0 +1,606 @@
import time
from typing import List, Optional, Any, Dict, Tuple
from datetime import datetime, timedelta
from psycopg import AsyncConnection
import pandas as pd
import numpy as np
from api_ex.Fdataclean import clean_flow_data_df_kf
from api_ex.Pdataclean import clean_pressure_data_df_km
from api_ex.pipeline_health_analyzer import PipelineHealthAnalyzer
from postgresql.internal_queries import InternalQueries
from postgresql.scada_info import ScadaRepository as PostgreScadaRepository
from timescaledb.schemas.realtime import RealtimeRepository
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.scada import ScadaRepository
class CompositeQueries:
"""
复合查询类,提供跨表查询功能
"""
@staticmethod
async def get_scada_associated_realtime_simulation_data(
timescale_conn: AsyncConnection,
postgres_conn: AsyncConnection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 SCADA 关联的 link/node 模拟值
根据传入的 SCADA device_ids,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
Args:
timescale_conn: TimescaleDB 异步连接
postgres_conn: PostgreSQL 异步连接
device_ids: SCADA 设备ID列表
start_time: 开始时间
end_time: 结束时间
Returns:
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises:
ValueError: 当 SCADA 设备未找到或字段无效时
"""
result = {}
# 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
for device_id in device_ids:
# 2. 根据 device_id 找到对应的 SCADA 信息
target_scada = None
for scada in scada_infos:
if scada["id"] == device_id:
target_scada = scada
break
if not target_scada:
raise ValueError(f"SCADA device {device_id} not found")
# 3. 根据 type 和 associated_element_id 查询对应的模拟数据
element_id = target_scada["associated_element_id"]
scada_type = target_scada["type"]
if scada_type.lower() == "pipe_flow":
# 查询 link 模拟数据
res = await RealtimeRepository.get_link_field_by_time_range(
timescale_conn, start_time, end_time, element_id, "flow"
)
elif scada_type.lower() == "pressure":
# 查询 node 模拟数据
res = await RealtimeRepository.get_node_field_by_time_range(
timescale_conn, start_time, end_time, element_id, "pressure"
)
else:
raise ValueError(f"Unknown SCADA type: {scada_type}")
# 添加 scada_id 到每个数据项
for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod
async def get_scada_associated_scheme_simulation_data(
timescale_conn: AsyncConnection,
postgres_conn: AsyncConnection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
scheme_type: str,
scheme_name: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 SCADA 关联的 link/node scheme 模拟值
根据传入的 SCADA device_ids,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
Args:
timescale_conn: TimescaleDB 异步连接
postgres_conn: PostgreSQL 异步连接
device_ids: SCADA 设备ID列表
start_time: 开始时间
end_time: 结束时间
Returns:
模拟数据字典,以 device_id 为键,值为数据列表,每个数据包含 time, value 和 scada_id
Raises:
ValueError: 当 SCADA 设备未找到或字段无效时
"""
result = {}
# 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
for device_id in device_ids:
# 2. 根据 device_id 找到对应的 SCADA 信息
target_scada = None
for scada in scada_infos:
if scada["id"] == device_id:
target_scada = scada
break
if not target_scada:
raise ValueError(f"SCADA device {device_id} not found")
# 3. 根据 type 和 associated_element_id 查询对应的模拟数据
element_id = target_scada["associated_element_id"]
scada_type = target_scada["type"]
if scada_type.lower() == "pipe_flow":
# 查询 link 模拟数据
res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
timescale_conn,
scheme_type,
scheme_name,
start_time,
end_time,
element_id,
"flow",
)
elif scada_type.lower() == "pressure":
# 查询 node 模拟数据
res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
timescale_conn,
scheme_type,
scheme_name,
start_time,
end_time,
element_id,
"pressure",
)
else:
raise ValueError(f"Unknown SCADA type: {scada_type}")
# 添加 scada_id 到每个数据项
for item in res:
item["scada_id"] = device_id
result[device_id] = res
return result
@staticmethod
async def get_realtime_simulation_data(
timescale_conn: AsyncConnection,
featureInfos: List[Tuple[str, str]],
start_time: datetime,
end_time: datetime,
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 link/node 模拟值
根据传入的 featureInfos,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
Args:
timescale_conn: TimescaleDB 异步连接
featureInfos: 传入的 feature 信息列表,包含 (element_id, type)
start_time: 开始时间
end_time: 结束时间
Returns:
模拟数据字典,以 feature_id 为键,值为数据列表,每个数据包含 time, value 和 feature_id
Raises:
ValueError: 当 SCADA 设备未找到或字段无效时
"""
result = {}
for feature_id, type in featureInfos:
if type.lower() == "pipe":
# 查询 link 模拟数据
res = await RealtimeRepository.get_link_field_by_time_range(
timescale_conn, start_time, end_time, feature_id, "flow"
)
elif type.lower() == "junction":
# 查询 node 模拟数据
res = await RealtimeRepository.get_node_field_by_time_range(
timescale_conn, start_time, end_time, feature_id, "pressure"
)
else:
raise ValueError(f"Unknown type: {type}")
# 添加 scada_id 到每个数据项
for item in res:
item["feature_id"] = feature_id
result[feature_id] = res
return result
@staticmethod
async def get_scheme_simulation_data(
timescale_conn: AsyncConnection,
featureInfos: List[Tuple[str, str]],
start_time: datetime,
end_time: datetime,
scheme_type: str,
scheme_name: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
获取 link/node scheme 模拟值
根据传入的 featureInfos,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
Args:
timescale_conn: TimescaleDB 异步连接
featureInfos: 传入的 feature 信息列表,包含 (element_id, type)
start_time: 开始时间
end_time: 结束时间
scheme_type: 工况类型
scheme_name: 工况名称
Returns:
模拟数据字典,以 feature_id 为键,值为数据列表,每个数据包含 time, value 和 feature_id
Raises:
ValueError: 当类型无效时
"""
result = {}
for feature_id, type in featureInfos:
if type.lower() == "pipe":
# 查询 link 模拟数据
res = await SchemeRepository.get_link_field_by_scheme_and_time_range(
timescale_conn,
scheme_type,
scheme_name,
start_time,
end_time,
feature_id,
"flow",
)
elif type.lower() == "junction":
# 查询 node 模拟数据
res = await SchemeRepository.get_node_field_by_scheme_and_time_range(
timescale_conn,
scheme_type,
scheme_name,
start_time,
end_time,
feature_id,
"pressure",
)
else:
raise ValueError(f"Unknown type: {type}")
# 添加 feature_id 到每个数据项
for item in res:
item["feature_id"] = feature_id
result[feature_id] = res
return result
@staticmethod
async def get_element_associated_scada_data(
timescale_conn: AsyncConnection,
postgres_conn: AsyncConnection,
element_id: str,
start_time: datetime,
end_time: datetime,
use_cleaned: bool = False,
) -> Optional[Any]:
"""
获取 link/node 关联的 SCADA 监测值
根据传入的 link/node id,匹配 SCADA 信息,
如果存在关联的 SCADA device_id,获取实际的监测数据
Args:
timescale_conn: TimescaleDB 异步连接
postgres_conn: PostgreSQL 异步连接
element_id: link 或 node 的 ID
start_time: 开始时间
end_time: 结束时间
use_cleaned: 是否使用清洗后的数据 (True: "cleaned_value", False: "monitored_value")
Returns:
SCADA 监测数据值,如果没有找到则返回 None
Raises:
ValueError: 当元素类型无效时
"""
# 1. 查询所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
# 2. 根据 element_type 和 element_id 找到关联的 SCADA 设备
associated_scada = None
for scada in scada_infos:
if scada["associated_element_id"] == element_id:
associated_scada = scada
break
if not associated_scada:
# 没有找到关联的 SCADA 设备
return None
# 3. 通过 SCADA device_id 获取监测数据
device_id = associated_scada["id"]
# 根据 use_cleaned 参数选择字段
data_field = "cleaned_value" if use_cleaned else "monitored_value"
# 保证 device_id 以列表形式传递
res = await ScadaRepository.get_scada_field_by_id_time_range(
timescale_conn, [device_id], start_time, end_time, data_field
)
# 将 device_id 替换为 element_id 返回
return {element_id: res.get(device_id, [])}
@staticmethod
async def clean_scada_data(
timescale_conn: AsyncConnection,
postgres_conn: AsyncConnection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
) -> str:
"""
清洗 SCADA 数据
根据 device_ids 查询 monitored_value,清洗后更新 cleaned_value
Args:
timescale_conn: TimescaleDB 连接
postgres_conn: PostgreSQL 连接
device_ids: 设备 ID 列表
start_time: 开始时间
end_time: 结束时间
Returns:
"success" 或错误信息
"""
try:
# 获取所有 SCADA 信息
scada_infos = await PostgreScadaRepository.get_scadas(postgres_conn)
# 将列表转换为字典,以 device_id 为键
scada_device_info_dict = {info["id"]: info for info in scada_infos}
# 如果 device_ids 为空,则处理所有 SCADA 设备
if not device_ids:
device_ids = list(scada_device_info_dict.keys())
# 批量查询所有设备的数据
data = await ScadaRepository.get_scada_field_by_id_time_range(
timescale_conn, device_ids, start_time, end_time, "monitored_value"
)
if not data:
return "error: fetch none scada data" # 没有数据,直接返回
# 将嵌套字典转换为 DataFrame,使用 time 作为索引
# data 格式: {device_id: [{"time": "...", "value": ...}, ...]}
all_records = []
for device_id, records in data.items():
for record in records:
all_records.append(
{
"time": record["time"],
"device_id": device_id,
"value": record["value"],
}
)
if not all_records:
return "error: fetch none scada data" # 没有数据,直接返回
# 创建 DataFrame 并透视,使 device_id 成为列
df_long = pd.DataFrame(all_records)
df = df_long.pivot(index="time", columns="device_id", values="value")
# 根据type分类设备
pressure_ids = [
id
for id in df.columns
if scada_device_info_dict.get(id, {}).get("type") == "pressure"
]
flow_ids = [
id
for id in df.columns
if scada_device_info_dict.get(id, {}).get("type") == "pipe_flow"
]
# 处理pressure数据
if pressure_ids:
pressure_df = df[pressure_ids]
# 重置索引,将 time 变为普通列
pressure_df = pressure_df.reset_index()
# 移除 time 列,准备输入给清洗方法
value_df = pressure_df.drop(columns=["time"])
# 调用清洗方法
cleaned_value_df = clean_pressure_data_df_km(value_df)
# 添加 time 列到首列
cleaned_df = pd.concat([pressure_df["time"], cleaned_value_df], axis=1)
# 将清洗后的数据写回数据库
for device_id in pressure_ids:
if device_id in cleaned_df.columns:
cleaned_values = cleaned_df[device_id].tolist()
time_values = cleaned_df["time"].tolist()
for i, time_str in enumerate(time_values):
time_dt = datetime.fromisoformat(time_str)
value = cleaned_values[i]
await ScadaRepository.update_scada_field(
timescale_conn,
time_dt,
device_id,
"cleaned_value",
value,
)
# 处理flow数据
if flow_ids:
flow_df = df[flow_ids]
# 重置索引,将 time 变为普通列
flow_df = flow_df.reset_index()
# 移除 time 列,准备输入给清洗方法
value_df = flow_df.drop(columns=["time"])
# 调用清洗方法
cleaned_value_df = clean_flow_data_df_kf(value_df)
# 添加 time 列到首列
cleaned_df = pd.concat([flow_df["time"], cleaned_value_df], axis=1)
# 将清洗后的数据写回数据库
for device_id in flow_ids:
if device_id in cleaned_df.columns:
cleaned_values = cleaned_df[device_id].tolist()
time_values = cleaned_df["time"].tolist()
for i, time_str in enumerate(time_values):
time_dt = datetime.fromisoformat(time_str)
value = cleaned_values[i]
await ScadaRepository.update_scada_field(
timescale_conn,
time_dt,
device_id,
"cleaned_value",
value,
)
return "success"
except Exception as e:
return f"error: {str(e)}"
@staticmethod
async def predict_pipeline_health(
timescale_conn: AsyncConnection,
network_name: str,
query_time: datetime,
) -> List[Dict[str, Any]]:
"""
预测管道健康状况
根据管网名称和当前时间,查询管道信息和实时数据,
使用随机生存森林模型预测管道的生存概率
Args:
timescale_conn: TimescaleDB 异步连接
db_name: 管网数据库名称
query_time: 查询时间
property_conditions: 可选的管道筛选条件,如 {"diameter": 300}
Returns:
预测结果列表,每个元素包含 link_id 和对应的生存函数
Raises:
ValueError: 当参数无效或数据不足时
FileNotFoundError: 当模型文件未找到时
"""
try:
# 1. 准备时间范围(查询时间前后1秒)
start_time = query_time - timedelta(seconds=1)
end_time = query_time + timedelta(seconds=1)
# 2. 先查询流速数据(velocity),获取有数据的管道ID列表
velocity_data = await RealtimeRepository.get_links_field_by_time_range(
timescale_conn, start_time, end_time, "velocity"
)
if not velocity_data:
raise ValueError("未找到流速数据")
# 3. 只查询有流速数据的管道的基本信息
valid_link_ids = list(velocity_data.keys())
# 批量查询这些管道的详细信息
fields = ["id", "diameter", "node1", "node2"]
all_links = InternalQueries.get_links_by_property(
fields=fields,
db_name=network_name,
)
# 转换为字典以快速查找
links_dict = {link["id"]: link for link in all_links}
# 获取所有需要查询的节点ID
node_ids = set()
for link_id in valid_link_ids:
if link_id in links_dict:
link = links_dict[link_id]
node_ids.add(link["node1"])
node_ids.add(link["node2"])
# 4. 批量查询压力数据(pressure)
pressure_data = await RealtimeRepository.get_nodes_field_by_time_range(
timescale_conn, start_time, end_time, "pressure"
)
# 5. 组合数据结构
materials = []
diameters = []
velocities = []
pressures = []
link_ids = []
for link_id in valid_link_ids:
# 跳过不在管道字典中的ID(如泵等其他元素)
if link_id not in links_dict:
continue
link = links_dict[link_id]
diameter = link["diameter"]
node1 = link["node1"]
node2 = link["node2"]
# 获取流速数据
velocity_values = velocity_data[link_id]
velocity = velocity_values[-1]["value"] if velocity_values else 0
# 获取node1和node2的压力数据,计算平均值
node1_pressure = 0
node2_pressure = 0
if node1 in pressure_data and pressure_data[node1]:
pressure_values = pressure_data[node1]
node1_pressure = (
pressure_values[-1]["value"] if pressure_values else 0
)
if node2 in pressure_data and pressure_data[node2]:
pressure_values = pressure_data[node2]
node2_pressure = (
pressure_values[-1]["value"] if pressure_values else 0
)
# 计算平均压力
avg_pressure = (node1_pressure + node2_pressure) / 2
# 添加到列表
link_ids.append(link_id)
materials.append(7) # 默认材料类型为7,可根据实际情况调整
diameters.append(diameter)
velocities.append(velocity)
pressures.append(avg_pressure)
if not link_ids:
raise ValueError("没有找到有效的管道数据用于预测")
# 6. 创建DataFrame
data = pd.DataFrame(
{
"Material": materials,
"Diameter": diameters,
"Flow Velocity": velocities,
"Pressure": pressures,
}
)
# 7. 使用PipelineHealthAnalyzer进行预测
analyzer = PipelineHealthAnalyzer(
model_path="api_ex/model/my_survival_forest_model_quxi.joblib"
)
survival_functions = analyzer.predict_survival(data)
# 8. 组合结果
results = []
for i, link_id in enumerate(link_ids):
sf = survival_functions[i]
results.append(
{
"link_id": link_id,
"survival_function": {
"x": sf.x.tolist(), # 时间点(年)
"y": sf.y.tolist(), # 生存概率
},
}
)
return results
except Exception as e:
raise ValueError(f"管道健康预测失败: {str(e)}")
+115
View File
@@ -0,0 +1,115 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
import psycopg_pool
from psycopg.rows import dict_row
import app.infra.db.timescaledb.timescaledb_info as timescaledb_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
def __init__(self, db_name=None):
self.pool = None
self.db_name = db_name
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
conn_string = timescaledb_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
min_size=5,
max_size=20,
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"TimescaleDB connection pool initialized for database: default"
)
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
raise
async def open(self):
if self.pool:
await self.pool.open()
async def close(self):
"""Close the connection pool."""
if self.pool:
await self.pool.close()
logger.info("TimescaleDB connection pool closed.")
def get_pgconn_string(self, db_name=None):
"""Get the TimescaleDB connection string."""
target_db_name = db_name or self.db_name
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:
"""Get a connection from the pool."""
if not self.pool:
raise Exception("Database pool is not initialized.")
async with self.pool.connection() as conn:
yield conn
# 默认数据库实例
db = Database()
# 缓存不同数据库的实例 - 避免重复创建连接池
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
"""Create a new Database instance for a specific database."""
return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
"""Get or create a database instance for the specified database name."""
if not db_name:
return db # 返回默认数据库实例
if db_name not in _database_instances:
# 创建新的数据库实例
instance = create_database_instance(db_name)
instance.init_pool()
await instance.open()
_database_instances[db_name] = instance
logger.info(f"Created new database instance for: {db_name}")
return _database_instances[db_name]
async def get_db_connection():
"""Dependency for FastAPI to get a database connection."""
async with db.get_connection() as conn:
yield conn
async def get_database_connection(db_name: Optional[str] = None):
"""
FastAPI dependency to get database connection with optional database name.
使用方法: conn: AsyncConnection = Depends(lambda: get_database_connection("your_db_name"))
或在路由函数中: conn: AsyncConnection = Depends(get_database_connection)
"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
async def cleanup_database_instances():
"""Clean up all database instances (call this on application shutdown)."""
for db_name, instance in _database_instances.items():
await instance.close()
logger.info(f"Closed database instance for: {db_name}")
_database_instances.clear()
# 关闭默认数据库
await db.close()
logger.info("All database instances cleaned up.")
@@ -0,0 +1,122 @@
from typing import List
from fastapi.logger import logger
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.realtime import RealtimeRepository
import timescaledb.timescaledb_info as timescaledb_info
from datetime import datetime, timedelta
from timescaledb.schemas.scada import ScadaRepository
import psycopg
import time
class InternalStorage:
@staticmethod
def store_realtime_simulation(
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
db_name: str = None,
max_retries: int = 3,
):
"""存储实时模拟结果"""
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
RealtimeRepository.store_realtime_simulation_result_sync(
conn, node_result_list, link_result_list, result_start_time
)
break # 成功
except Exception as e:
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1) # 重试前等待
else:
raise # 达到最大重试次数后抛出异常
@staticmethod
def store_scheme_simulation(
scheme_type: str,
scheme_name: str,
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
num_periods: int = 1,
db_name: str = None,
max_retries: int = 3,
):
"""存储方案模拟结果"""
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
SchemeRepository.store_scheme_simulation_result_sync(
conn,
scheme_type,
scheme_name,
node_result_list,
link_result_list,
result_start_time,
num_periods,
)
break # 成功
except Exception as e:
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1) # 重试前等待
else:
raise # 达到最大重试次数后抛出异常
class InternalQueries:
@staticmethod
def query_scada_by_ids_time(
device_ids: List[str],
query_time: str,
db_name: str = None,
max_retries: int = 3,
) -> dict:
"""查询指定时间点的 SCADA 数据"""
# 解析时间,假设是北京时间
beijing_time = datetime.fromisoformat(query_time)
start_time = beijing_time - timedelta(seconds=1)
end_time = beijing_time + timedelta(seconds=1)
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
rows = ScadaRepository.get_scada_by_ids_time_range_sync(
conn, device_ids, start_time, end_time
)
# 处理结果,返回每个 device_id 的第一个值
result = {}
for device_id in device_ids:
device_rows = [
row for row in rows if row["device_id"] == device_id
]
if device_rows:
result[device_id] = device_rows[0]["monitored_value"]
else:
result[device_id] = None
return result
except Exception as e:
logger.error(f"查询尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1)
else:
raise
+627
View File
@@ -0,0 +1,627 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Optional
from datetime import datetime
from psycopg import AsyncConnection
from .database import get_database_instance
from .schemas.realtime import RealtimeRepository
from .schemas.scheme import SchemeRepository
from .schemas.scada import ScadaRepository
from .composite_queries import CompositeQueries
from postgresql.database import get_database_instance as get_postgres_database_instance
router = APIRouter(prefix="/timescaledb", tags=["TimescaleDB"])
# 创建支持数据库选择的连接依赖函数
async def get_database_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的数据库名称,为空时使用默认数据库"
)
):
"""获取数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
# PostgreSQL 数据库连接依赖函数
async def get_postgres_connection(
db_name: Optional[str] = Query(
None, description="指定要连接的 PostgreSQL 数据库名称,为空时使用默认数据库"
)
):
"""获取 PostgreSQL 数据库连接,支持通过查询参数指定数据库名称"""
instance = await get_postgres_database_instance(db_name)
async with instance.get_connection() as conn:
yield conn
# --- Realtime Endpoints ---
@router.post("/realtime/links/batch", status_code=201)
async def insert_realtime_links(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await RealtimeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/realtime/links")
async def get_realtime_links(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
return await RealtimeRepository.get_links_by_time_range(conn, start_time, end_time)
@router.delete("/realtime/links")
async def delete_realtime_links(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await RealtimeRepository.delete_links_by_time_range(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.patch("/realtime/links/{link_id}/field")
async def update_realtime_link_field(
link_id: str,
time: datetime,
field: str,
value: float, # Assuming float for now, could be Any but FastAPI needs type
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await RealtimeRepository.update_link_field(conn, time, link_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/realtime/nodes/batch", status_code=201)
async def insert_realtime_nodes(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await RealtimeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/realtime/nodes")
async def get_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
return await RealtimeRepository.get_nodes_by_time_range(conn, start_time, end_time)
@router.delete("/realtime/nodes")
async def delete_realtime_nodes(
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await RealtimeRepository.delete_nodes_by_time_range(conn, start_time, end_time)
return {"message": "Deleted successfully"}
@router.post("/realtime/simulation/store", status_code=201)
async def store_realtime_simulation_result(
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Store realtime simulation results to TimescaleDB"""
await RealtimeRepository.store_realtime_simulation_result(
conn, node_result_list, link_result_list, result_start_time
)
return {"message": "Simulation results stored successfully"}
@router.get("/realtime/query/by-time-property")
async def query_realtime_records_by_time_property(
query_time: str,
type: str,
property: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query all realtime records by time and property"""
try:
results = await RealtimeRepository.query_all_record_by_time_property(
conn, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/realtime/query/by-id-time")
async def query_realtime_simulation_by_id_time(
id: str,
type: str,
query_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query realtime simulation results by id and time"""
try:
results = await RealtimeRepository.query_simulation_result_by_id_time(
conn, id, type, query_time
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# --- Scheme Endpoints ---
@router.post("/scheme/links/batch", status_code=201)
async def insert_scheme_links(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_links_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/links")
async def get_scheme_links(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
return await SchemeRepository.get_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
@router.get("/scheme/links/{link_id}/field")
async def get_scheme_link_field(
scheme_type: str,
scheme_name: str,
link_id: str,
start_time: datetime,
end_time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_link_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, link_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/links/{link_id}/field")
async def update_scheme_link_field(
scheme_type: str,
scheme_name: str,
link_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_link_field(
conn, time, scheme_type, scheme_name, link_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/links")
async def delete_scheme_links(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_links_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/nodes/batch", status_code=201)
async def insert_scheme_nodes(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await SchemeRepository.insert_nodes_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scheme/nodes/{node_id}/field")
async def get_scheme_node_field(
scheme_type: str,
scheme_name: str,
node_id: str,
start_time: datetime,
end_time: datetime,
field: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
return await SchemeRepository.get_node_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, node_id, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scheme/nodes/{node_id}/field")
async def update_scheme_node_field(
scheme_type: str,
scheme_name: str,
node_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await SchemeRepository.update_node_field(
conn, time, scheme_type, scheme_name, node_id, field, value
)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scheme/nodes")
async def delete_scheme_nodes(
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await SchemeRepository.delete_nodes_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time
)
return {"message": "Deleted successfully"}
@router.post("/scheme/simulation/store", status_code=201)
async def store_scheme_simulation_result(
scheme_type: str,
scheme_name: str,
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Store scheme simulation results to TimescaleDB"""
await SchemeRepository.store_scheme_simulation_result(
conn,
scheme_type,
scheme_name,
node_result_list,
link_result_list,
result_start_time,
)
return {"message": "Scheme simulation results stored successfully"}
@router.get("/scheme/query/by-scheme-time-property")
async def query_scheme_records_by_scheme_time_property(
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query all scheme records by scheme, time and property"""
try:
results = await SchemeRepository.query_all_record_by_scheme_time_property(
conn, scheme_type, scheme_name, query_time, type, property
)
return {"results": results}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/scheme/query/by-id-time")
async def query_scheme_simulation_by_id_time(
scheme_type: str,
scheme_name: str,
id: str,
type: str,
query_time: str,
conn: AsyncConnection = Depends(get_database_connection),
):
"""Query scheme simulation results by id and time"""
try:
result = await SchemeRepository.query_scheme_simulation_result_by_id_time(
conn, scheme_type, scheme_name, id, type, query_time
)
return {"result": result}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# --- SCADA Endpoints ---
@router.post("/scada/batch", status_code=201)
async def insert_scada_data(
data: List[dict], conn: AsyncConnection = Depends(get_database_connection)
):
await ScadaRepository.insert_scada_batch(conn, data)
return {"message": f"Inserted {len(data)} records"}
@router.get("/scada/by-ids-time-range")
async def get_scada_by_ids_time_range(
start_time: datetime,
end_time: datetime,
device_ids: str,
conn: AsyncConnection = Depends(get_database_connection),
):
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()] if device_ids else []
)
return await ScadaRepository.get_scada_by_ids_time_range(
conn, device_ids_list, start_time, end_time
)
@router.get("/scada/by-ids-field-time-range")
async def get_scada_field_by_ids_time_range(
start_time: datetime,
end_time: datetime,
field: str,
device_ids: str,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
return await ScadaRepository.get_scada_field_by_id_time_range(
conn, device_ids_list, start_time, end_time, field
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.patch("/scada/{device_id}/field")
async def update_scada_field(
device_id: str,
time: datetime,
field: str,
value: float,
conn: AsyncConnection = Depends(get_database_connection),
):
try:
await ScadaRepository.update_scada_field(conn, time, device_id, field, value)
return {"message": "Updated successfully"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/scada/by-id-time-range")
async def delete_scada_data(
device_id: str,
start_time: datetime,
end_time: datetime,
conn: AsyncConnection = Depends(get_database_connection),
):
await ScadaRepository.delete_scada_by_id_time_range(
conn, device_id, start_time, end_time
)
return {"message": "Deleted successfully"}
# --- Composite Query Endpoints ---
@router.get("/composite/scada-simulation")
async def get_scada_associated_simulation_data(
start_time: datetime,
end_time: datetime,
device_ids: str,
scheme_type: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
timescale_conn: AsyncConnection = Depends(get_database_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
获取 SCADA 关联的 link/node 模拟值
根据传入的 SCADA device_ids,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
"""
try:
# 手动解析 device_ids 为 List[str],去除空格
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
if scheme_type and scheme_name:
result = await CompositeQueries.get_scada_associated_scheme_simulation_data(
timescale_conn,
postgres_conn,
device_ids_list, # 使用解析后的列表
start_time,
end_time,
scheme_type,
scheme_name,
)
else:
result = (
await CompositeQueries.get_scada_associated_realtime_simulation_data(
timescale_conn,
postgres_conn,
device_ids_list, # 使用解析后的列表
start_time,
end_time,
)
)
if result is None:
raise HTTPException(status_code=404, detail="No simulation data found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-simulation")
async def get_feature_simulation_data(
start_time: datetime,
end_time: datetime,
feature_infos: str = Query(
..., description="特征信息,格式: id1:type1,id2:type2type为pipe或junction"
),
scheme_type: str = Query(None, description="指定方案类型,若为空则查询实时数据"),
scheme_name: str = Query(None, description="指定方案名称,若为空则查询实时数据"),
timescale_conn: AsyncConnection = Depends(get_database_connection),
):
"""
获取 link/node 模拟值
根据传入的 featureInfos,找到关联的 link/node
并根据对应的 type,查询对应的模拟数据
Args:
feature_infos: 格式为 "element_id1:type1,element_id2:type2"
例如: "P1:pipe,J1:junction"
"""
try:
# 解析 feature_infos 为 List[Tuple[str, str]]
feature_infos_list = []
if feature_infos:
for item in feature_infos.split(","):
item = item.strip()
if ":" in item:
element_id, element_type = item.split(":", 1)
feature_infos_list.append(
(element_id.strip(), element_type.strip())
)
if not feature_infos_list:
raise HTTPException(status_code=400, detail="feature_infos cannot be empty")
if scheme_type and scheme_name:
result = await CompositeQueries.get_scheme_simulation_data(
timescale_conn,
feature_infos_list,
start_time,
end_time,
scheme_type,
scheme_name,
)
else:
result = await CompositeQueries.get_realtime_simulation_data(
timescale_conn,
feature_infos_list,
start_time,
end_time,
)
if result is None:
raise HTTPException(status_code=404, detail="No simulation data found")
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/element-scada")
async def get_element_associated_scada_data(
element_id: str,
start_time: datetime,
end_time: datetime,
use_cleaned: bool = Query(False, description="是否使用清洗后的数据"),
timescale_conn: AsyncConnection = Depends(get_database_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
获取 link/node 关联的 SCADA 监测值
根据传入的 link/node id,匹配 SCADA 信息,
如果存在关联的 SCADA device_id,获取实际的监测数据
"""
try:
result = await CompositeQueries.get_element_associated_scada_data(
timescale_conn, postgres_conn, element_id, start_time, end_time, use_cleaned
)
if result is None:
raise HTTPException(
status_code=404, detail="No associated SCADA data found"
)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/composite/clean-scada")
async def clean_scada_data(
device_ids: str,
start_time: datetime = Query(...),
end_time: datetime = Query(...),
timescale_conn: AsyncConnection = Depends(get_database_connection),
postgres_conn: AsyncConnection = Depends(get_postgres_connection),
):
"""
清洗 SCADA 数据
根据 device_ids 查询 monitored_value,清洗后更新 cleaned_value
"""
try:
if device_ids == "all":
device_ids_list = []
else:
device_ids_list = (
[id.strip() for id in device_ids.split(",") if id.strip()]
if device_ids
else []
)
return await CompositeQueries.clean_scada_data(
timescale_conn, postgres_conn, device_ids_list, start_time, end_time
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/composite/pipeline-health-prediction")
async def predict_pipeline_health(
query_time: datetime = Query(..., description="查询时间"),
network_name: str = Query(..., description="管网数据库名称"),
timescale_conn: AsyncConnection = Depends(get_database_connection),
):
"""
预测管道健康状况
根据管网名称和当前时间,查询管道信息和实时数据,
使用随机生存森林模型预测管道的生存概率
Args:
query_time: 查询时间
db_name: 管网数据库名称
Returns:
预测结果列表,每个元素包含 link_id 和对应的生存函数
"""
try:
return await CompositeQueries.predict_pipeline_health(
timescale_conn, network_name, query_time
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
@@ -0,0 +1,647 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
# 定义UTC+8时区
UTC_8 = timezone(timedelta(hours=8))
class RealtimeRepository:
# --- Link Simulation ---
@staticmethod
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
"""Batch insert for realtime.link_simulation using DELETE then COPY for performance."""
if not data:
return
# 假设同一批次的数据时间是相同的
target_time = data[0]["time"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 先删除该时间点的旧数据
await cur.execute(
"DELETE FROM realtime.link_simulation WHERE time = %s",
(target_time,),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
def insert_links_batch_sync(conn: Connection, data: List[dict]):
"""Batch insert for realtime.link_simulation using DELETE then COPY for performance (sync version)."""
if not data:
return
# 假设同一批次的数据时间是相同的
target_time = data[0]["time"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 先删除该时间点的旧数据
cur.execute(
"DELETE FROM realtime.link_simulation WHERE time = %s",
(target_time,),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_link_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s",
(start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def get_links_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> List[Dict[str, Any]]:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time, link_id))
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_links_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM realtime.link_simulation WHERE time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_link_field(
conn: AsyncConnection,
time: datetime,
link_id: str,
field: str,
value: Any,
):
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE realtime.link_simulation SET {} = %s WHERE time = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, link_id))
@staticmethod
async def delete_links_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM realtime.link_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)
# --- Node Simulation ---
@staticmethod
async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
if not data:
return
# 假设同一批次的数据时间是相同的
target_time = data[0]["time"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 先删除该时间点的旧数据
await cur.execute(
"DELETE FROM realtime.node_simulation WHERE time = %s",
(target_time,),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
if not data:
return
# 假设同一批次的数据时间是相同的
target_time = data[0]["time"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 先删除该时间点的旧数据
cur.execute(
"DELETE FROM realtime.node_simulation WHERE time = %s",
(target_time,),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_node_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s",
(start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def get_nodes_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field_by_time_range(
conn: AsyncConnection,
start_time: datetime,
end_time: datetime,
node_id: str,
field: str,
) -> List[Dict[str, Any]]:
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time, node_id))
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_nodes_field_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, field: str
) -> dict:
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM realtime.node_simulation WHERE time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_node_field(
conn: AsyncConnection,
time: datetime,
node_id: str,
field: str,
value: Any,
):
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE realtime.node_simulation SET {} = %s WHERE time = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, node_id))
@staticmethod
async def delete_nodes_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM realtime.node_simulation WHERE time >= %s AND time <= %s",
(start_time, end_time),
)
# --- 复合查询 ---
@staticmethod
async def store_realtime_simulation_result(
conn: AsyncConnection,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
):
"""
Store realtime simulation results to TimescaleDB.
Args:
conn: Database connection
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串,解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
node_data.append(
{
"time": simulation_time,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
data = link_result.get("result", [])[0]
link_data.append(
{
"time": simulation_time,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
await RealtimeRepository.insert_nodes_batch(conn, node_data)
if link_data:
await RealtimeRepository.insert_links_batch(conn, link_data)
@staticmethod
def store_realtime_simulation_result_sync(
conn: Connection,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
):
"""
Store realtime simulation results to TimescaleDB (sync version).
Args:
conn: Database connection
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串,解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
node_data.append(
{
"time": simulation_time,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
data = link_result.get("result", [])[0]
link_data.append(
{
"time": simulation_time,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
if link_data:
RealtimeRepository.insert_links_batch_sync(conn, link_data)
@staticmethod
async def query_all_record_by_time_property(
conn: AsyncConnection,
query_time: str,
type: str,
property: str,
) -> list:
"""
Query all records by time and property from TimescaleDB.
Args:
conn: Database connection
query_time: Time to query (ISO format string)
type: Type of data ("node" or "link")
property: Property/field to query
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
data = await RealtimeRepository.get_nodes_field_by_time_range(
conn, start_time, end_time, property
)
elif type.lower() == "link":
data = await RealtimeRepository.get_links_field_by_time_range(
conn, start_time, end_time, property
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
# Format the results
result = []
for id, items in data.items():
for item in items:
result.append({"ID": id, "value": item["value"]})
return result
@staticmethod
async def query_simulation_result_by_id_time(
conn: AsyncConnection,
id: str,
type: str,
query_time: str,
) -> list[dict]:
"""
Query simulation results by id and time from TimescaleDB.
Args:
conn: Database connection
id: The id of the node or link
type: Type of data ("node" or "link")
query_time: Time to query (ISO format string)
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
return await RealtimeRepository.get_node_by_time_range(
conn, start_time, end_time, id
)
elif type.lower() == "link":
return await RealtimeRepository.get_link_by_time_range(
conn, start_time, end_time, id
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
+106
View File
@@ -0,0 +1,106 @@
from typing import List, Any
from datetime import datetime
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
class ScadaRepository:
@staticmethod
async def insert_scada_batch(conn: AsyncConnection, data: List[dict]):
if not data:
return
async with conn.cursor() as cur:
async with cur.copy(
"COPY scada.scada_data (time, device_id, monitored_value, cleaned_value) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["device_id"],
item.get("monitored_value"),
item.get("cleaned_value"),
)
)
@staticmethod
async def get_scada_by_ids_time_range(
conn: AsyncConnection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
(device_ids, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
def get_scada_by_ids_time_range_sync(
conn: Connection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
) -> List[dict]:
with conn.cursor() as cur:
cur.execute(
"SELECT * FROM scada.scada_data WHERE device_id = ANY(%s) AND time >= %s AND time <= %s",
(device_ids, start_time, end_time),
)
return cur.fetchall()
@staticmethod
async def get_scada_field_by_id_time_range(
conn: AsyncConnection,
device_ids: List[str],
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
valid_fields = {"monitored_value", "cleaned_value"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT device_id, time, {} FROM scada.scada_data WHERE time >= %s AND time <= %s AND device_id = ANY(%s)"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (start_time, end_time, device_ids))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["device_id"]].append({
"time": row["time"].isoformat(),
"value": row[field]
})
return dict(result)
@staticmethod
async def update_scada_field(
conn: AsyncConnection, time: datetime, device_id: str, field: str, value: Any
):
valid_fields = {"monitored_value", "cleaned_value"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scada.scada_data SET {} = %s WHERE time = %s AND device_id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, device_id))
@staticmethod
async def delete_scada_by_id_time_range(
conn: AsyncConnection, device_id: str, start_time: datetime, end_time: datetime
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scada.scada_data WHERE device_id = %s AND time >= %s AND time <= %s",
(device_id, start_time, end_time),
)
+760
View File
@@ -0,0 +1,760 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from collections import defaultdict
from psycopg import AsyncConnection, Connection, sql
import globals
# 定义UTC+8时区
UTC_8 = timezone(timedelta(hours=8))
class SchemeRepository:
# --- Link Simulation ---
@staticmethod
async def insert_links_batch(conn: AsyncConnection, data: List[dict]):
"""Batch insert for scheme.link_simulation using DELETE then COPY for performance."""
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
def insert_links_batch_sync(conn: Connection, data: List[dict]):
"""Batch insert for scheme.link_simulation using DELETE then COPY for performance (sync version)."""
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
cur.execute(
"DELETE FROM scheme.link_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_link_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, link_id),
)
return await cur.fetchall()
@staticmethod
async def get_links_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_link_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
link_id: str,
field: str,
) -> List[Dict[str, Any]]:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, link_id)
)
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_links_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
# Validate field name to prevent SQL injection
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_link_field(
conn: AsyncConnection,
time: datetime,
scheme_type: str,
scheme_name: str,
link_id: str,
field: str,
value: Any,
):
valid_fields = {
"flow",
"friction",
"headloss",
"quality",
"reaction",
"setting",
"status",
"velocity",
}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.link_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme_type, scheme_name, link_id))
@staticmethod
async def delete_links_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.link_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
# --- Node Simulation ---
@staticmethod
async def insert_nodes_batch(conn: AsyncConnection, data: List[dict]):
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
async with conn.transaction():
async with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
async with cur.copy(
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
await copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
if not data:
return
# 获取批次中所有不同的时间点
all_times = list(set(item["time"] for item in data))
target_scheme_type = data[0]["scheme_type"]
target_scheme_name = data[0]["scheme_name"]
# 使用事务确保原子性
with conn.transaction():
with conn.cursor() as cur:
# 1. 删除该批次涉及的所有时间点、scheme_type、scheme_name 的旧数据
cur.execute(
"DELETE FROM scheme.node_simulation WHERE time = ANY(%s) AND scheme_type = %s AND scheme_name = %s",
(all_times, target_scheme_type, target_scheme_name),
)
# 2. 使用 COPY 快速写入新数据
with cur.copy(
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_node_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s",
(scheme_type, scheme_name, start_time, end_time, node_id),
)
return await cur.fetchall()
@staticmethod
async def get_nodes_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
) -> List[dict]:
async with conn.cursor() as cur:
await cur.execute(
"SELECT * FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
return await cur.fetchall()
@staticmethod
async def get_node_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
node_id: str,
field: str,
) -> List[Dict[str, Any]]:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(
query, (scheme_type, scheme_name, start_time, end_time, node_id)
)
rows = await cur.fetchall()
return [
{"time": row["time"].isoformat(), "value": row[field]} for row in rows
]
@staticmethod
async def get_nodes_field_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
field: str,
) -> dict:
# Validate field name to prevent SQL injection
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"SELECT id, time, {} FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (scheme_type, scheme_name, start_time, end_time))
rows = await cur.fetchall()
result = defaultdict(list)
for row in rows:
result[row["id"]].append(
{"time": row["time"].isoformat(), "value": row[field]}
)
return dict(result)
@staticmethod
async def update_node_field(
conn: AsyncConnection,
time: datetime,
scheme_type: str,
scheme_name: str,
node_id: str,
field: str,
value: Any,
):
valid_fields = {"actual_demand", "total_head", "pressure", "quality"}
if field not in valid_fields:
raise ValueError(f"Invalid field: {field}")
query = sql.SQL(
"UPDATE scheme.node_simulation SET {} = %s WHERE time = %s AND scheme_type = %s AND scheme_name = %s AND id = %s"
).format(sql.Identifier(field))
async with conn.cursor() as cur:
await cur.execute(query, (value, time, scheme_type, scheme_name, node_id))
@staticmethod
async def delete_nodes_by_scheme_and_time_range(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
start_time: datetime,
end_time: datetime,
):
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM scheme.node_simulation WHERE scheme_type = %s AND scheme_name = %s AND time >= %s AND time <= %s",
(scheme_type, scheme_name, start_time, end_time),
)
# --- 复合查询 ---
@staticmethod
async def store_scheme_simulation_result(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串,解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
await SchemeRepository.insert_nodes_batch(conn, node_data)
if link_data:
await SchemeRepository.insert_links_batch(conn, link_data)
@staticmethod
def store_scheme_simulation_result_sync(
conn: Connection,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB (sync version).
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串,解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
SchemeRepository.insert_nodes_batch_sync(conn, node_data)
if link_data:
SchemeRepository.insert_links_batch_sync(conn, link_data)
@staticmethod
async def query_all_record_by_scheme_time_property(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
query_time: str,
type: str,
property: str,
) -> list:
"""
Query all records by scheme, time and property from TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
query_time: Time to query (ISO format string)
type: Type of data ("node" or "link")
property: Property/field to query
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
data = await SchemeRepository.get_nodes_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, property
)
elif type.lower() == "link":
data = await SchemeRepository.get_links_field_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, property
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
# Format the results
# Format the results
result = []
for id, items in data.items():
for item in items:
result.append({"ID": id, "value": item["value"]})
return result
@staticmethod
async def query_scheme_simulation_result_by_id_time(
conn: AsyncConnection,
scheme_type: str,
scheme_name: str,
id: str,
type: str,
query_time: str,
) -> list[dict]:
"""
Query scheme simulation results by id and time from TimescaleDB.
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
id: The id of the node or link
type: Type of data ("node" or "link")
query_time: Time to query (ISO format string)
Returns:
List of records matching the criteria
"""
# Convert query_time string to datetime
if isinstance(query_time, str):
if query_time.endswith("Z"):
# UTC时间,转换为UTC+8
utc_time = datetime.fromisoformat(query_time.replace("Z", "+00:00"))
target_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
target_time = datetime.fromisoformat(query_time)
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
else:
target_time = query_time
if target_time.tzinfo is None:
target_time = target_time.replace(tzinfo=UTC_8)
# Create time range: query_time ± 1 second
start_time = target_time - timedelta(seconds=1)
end_time = target_time + timedelta(seconds=1)
# Query based on type
if type.lower() == "node":
return await SchemeRepository.get_node_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, id
)
elif type.lower() == "link":
return await SchemeRepository.get_link_by_scheme_and_time_range(
conn, scheme_type, scheme_name, start_time, end_time, id
)
else:
raise ValueError(f"Invalid type: {type}. Must be 'node' or 'link'")
@@ -0,0 +1,36 @@
from dotenv import load_dotenv
import os
load_dotenv()
pg_name = os.getenv("TIMESCALEDB_DB_NAME")
pg_host = os.getenv("TIMESCALEDB_DB_HOST")
pg_port = os.getenv("TIMESCALEDB_DB_PORT")
pg_user = os.getenv("TIMESCALEDB_DB_USER")
pg_password = os.getenv("TIMESCALEDB_DB_PASSWORD")
def get_pgconn_string(
db_name=pg_name,
db_host=pg_host,
db_port=pg_port,
db_user=pg_user,
db_password=pg_password,
):
"""返回 PostgreSQL 连接字符串"""
return f"dbname={db_name} host={db_host} port={db_port} user={db_user} password={db_password}"
def get_pg_config():
"""返回 PostgreSQL 配置变量的字典"""
return {
"name": pg_name,
"host": pg_host,
"port": pg_port,
"user": pg_user,
}
def get_pg_password():
"""返回密码(谨慎使用)"""
return pg_password
View File