新增同步存储方法;新增run_server.py文件;修改默认的数据库连接方式;

This commit is contained in:
JIANG
2025-12-08 17:33:50 +08:00
parent 4fbdea435b
commit 44119c9725
9 changed files with 446 additions and 102 deletions

92
main.py
View File

@@ -1,11 +1,17 @@
import asyncio, os, io, json, time, pickle, redis, datetime, logging, threading, uvicorn, multiprocessing, asyncio, shutil, random
import os
import json
import time
import datetime
import logging
import threading
import shutil
import random
from typing import *
from typing import List, Annotated, Optional, Union
from urllib.request import Request
from xml.dom import minicompat
from pydantic import BaseModel
from starlette.responses import FileResponse, JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import Receive
from fastapi import (
FastAPI,
File,
@@ -13,47 +19,39 @@ from fastapi import (
Response,
status,
Request,
Body,
HTTPException,
Query,
Depends,
Header,
)
from fastapi.responses import PlainTextResponse
from fastapi.middleware.gzip import GZipMiddleware
from tjnetwork import *
from multiprocessing import Value
import uvicorn
import msgpack
from run_simulation import run_simulation, run_simulation_ex
from online_Analysis import *
from fastapi.middleware.cors import CORSMiddleware
from influxdb_client import (
InfluxDBClient,
BucketsApi,
WriteApi,
OrganizationsApi,
Point,
QueryApi,
)
from typing import List, Dict
from starlette.responses import FileResponse, JSONResponse
from contextlib import asynccontextmanager
from pydantic import BaseModel
from multiprocessing import Value
import redis
import msgpack
from datetime import datetime, timedelta, timezone
from dateutil import parser
import influxdb_info
# 第三方/自定义模块
import influxdb_api
import timescaledb
import py_linq
import time_api
import simulation
import globals
import os
import logging
import threading
import time
from logging.handlers import TimedRotatingFileHandler
from fastapi import FastAPI, APIRouter, Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from fastapi import FastAPI, Depends, HTTPException, Header
from typing import Annotated
import project_info
from timescaledb.database import db as tsdb
from postgresql.database import db as pgdb
from online_Analysis import *
from tjnetwork import *
JUNCTION = 0
RESERVOIR = 1
@@ -116,6 +114,25 @@ async def verify_token(authorization: Annotated[str, Header()] = None):
# app = FastAPI(dependencies=[Depends(global_auth)])
app = FastAPI()
# 生命周期管理器
@asynccontextmanager
async def lifespan(app: FastAPI):
# 初始化数据库连接池
tsdb.init_pool()
pgdb.init_pool()
await tsdb.open()
await pgdb.open()
yield
# 清理资源
tsdb.close()
pgdb.close()
app = FastAPI(lifespan=lifespan)
app.include_router(timescaledb.router)
access_tokens = []
@@ -3466,9 +3483,11 @@ async def fastapi_run_simulation_manually_by_date(
item["name"], base_date, item["start_time"], item["duration"]
)
)
thread.start()
thread.join() # 等待线程完成
return {"status": "success"}
except Exception as e:
return {"status": "error", "message": str(e)}
@@ -4213,11 +4232,10 @@ async def get_dict(item: Item):
if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
# url='http://127.0.0.1:8000/valve_close_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&valve_IDs=GSD2307192058577780A3287D78&valve_IDs=GSD2307192058572E953B707226(S2)&duration=1800'
# url='http://127.0.0.1:8000/burst_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&burst_ID=ZBBGXSZW000001&duration=1800'
url='http://127.0.0.1:8000/burst_analysis?network=beibeizone&start_time=2024-04-01T08:00:00Z&burst_ID=ZBBGXSZW000001&duration=1800'
# url = "http://192.168.1.36:8000/queryallschemeallrecords/?schemename=Fangan0817114448&querydate=2025-08-13&schemetype=burst_Analysis"
# response = Request.get(url)
# import requests
import requests
# response = requests.get(url)
print(get_all_scada_info("szh"))
response = requests.get(url)

View File

@@ -17,8 +17,7 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
target_db_name = db_name or self.db_name
conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name)
conn_string = postgresql_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
@@ -28,7 +27,7 @@ class Database:
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}"
f"PostgreSQL connection pool initialized for database: 'default'"
)
except Exception as e:
logger.error(f"Failed to initialize postgresql connection pool: {e}")

6
run_server.py Normal file
View File

@@ -0,0 +1,6 @@
import asyncio
import uvicorn
if __name__ == "__main__":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
uvicorn.run("main:app", host="0.0.0.0", port=8000)

View File

@@ -21,7 +21,6 @@ import globals
import uuid
import project_info
from api.postgresql_info import get_pgconn_string
import asyncio
from timescaledb.internal_queries import InternalStorage as TimescaleInternalStorage
logging.basicConfig(
@@ -1231,13 +1230,10 @@ def run_simulation(
# print(node_result)
# 存储
if simulation_type.upper() == "REALTIME":
asyncio.run(
TimescaleInternalStorage.store_realtime_simulation(
node_result, link_result, modify_pattern_start_time
)
)
elif simulation_type.upper() == "EXTENDED":
asyncio.run(
TimescaleInternalStorage.store_scheme_simulation(
scheme_Type,
scheme_Name,
@@ -1246,7 +1242,6 @@ def run_simulation(
modify_pattern_start_time,
num_periods_result,
)
)
# 暂不需要再次存储 SCADA 模拟信息
# TimescaleInternalStorage.fill_scheme_simulation_result_to_SCADA(scheme_Type=scheme_Type, scheme_Name=scheme_Name)

View File

@@ -8,6 +8,7 @@ import timescaledb.timescaledb_info as timescaledb_info
# Configure logging
logger = logging.getLogger(__name__)
class Database:
def __init__(self, db_name=None):
self.pool = None
@@ -16,17 +17,18 @@ class Database:
def init_pool(self, db_name=None):
"""Initialize the connection pool."""
# Use provided db_name, or the one from constructor, or default from config
target_db_name = db_name or self.db_name
conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name)
conn_string = timescaledb_info.get_pgconn_string()
try:
self.pool = psycopg_pool.AsyncConnectionPool(
conninfo=conn_string,
min_size=1,
max_size=20,
open=False, # Don't open immediately, wait for startup
kwargs={"row_factory": dict_row} # Return rows as dictionaries
kwargs={"row_factory": dict_row}, # Return rows as dictionaries
)
logger.info(
f"TimescaleDB connection pool initialized for database: default"
)
logger.info(f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}")
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB connection pool: {e}")
raise
@@ -41,6 +43,11 @@ class Database:
await self.pool.close()
logger.info("TimescaleDB connection pool closed.")
def get_pgconn_string(self, db_name=None):
"""Get the TimescaleDB connection string."""
target_db_name = db_name or self.db_name
return timescaledb_info.get_pgconn_string(db_name=target_db_name)
@asynccontextmanager
async def get_connection(self) -> AsyncGenerator:
"""Get a connection from the pool."""
@@ -50,16 +57,19 @@ class Database:
async with self.pool.connection() as conn:
yield conn
# 默认数据库实例
db = Database()
# 缓存不同数据库的实例 - 避免重复创建连接池
_database_instances: Dict[str, Database] = {}
def create_database_instance(db_name):
"""Create a new Database instance for a specific database."""
return Database(db_name=db_name)
async def get_database_instance(db_name: Optional[str] = None) -> Database:
"""Get or create a database instance for the specified database name."""
if not db_name:
@@ -75,11 +85,13 @@ async def get_database_instance(db_name: Optional[str] = None) -> Database:
return _database_instances[db_name]
async def get_db_connection():
"""Dependency for FastAPI to get a database connection."""
async with db.get_connection() as conn:
yield conn
async def get_database_connection(db_name: Optional[str] = None):
"""
FastAPI dependency to get database connection with optional database name.
@@ -90,6 +102,7 @@ async def get_database_connection(db_name: Optional[str] = None):
async with instance.get_connection() as conn:
yield conn
async def cleanup_database_instances():
"""Clean up all database instances (call this on application shutdown)."""
for db_name, instance in _database_instances.items():

View File

@@ -1,28 +1,49 @@
from typing import List
from fastapi.logger import logger
from timescaledb.schemas.scheme import SchemeRepository
from timescaledb.schemas.realtime import RealtimeRepository
from timescaledb.database import get_database_instance
import timescaledb.timescaledb_info as timescaledb_info
import psycopg
import time
# 内部使用存储类
class InternalStorage:
@staticmethod
async def store_realtime_simulation(
def store_realtime_simulation(
node_result_list: List[dict],
link_result_list: List[dict],
result_start_time: str,
db_name: str = None,
max_retries: int = 3,
):
"""存储实时模拟结果"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
await RealtimeRepository.store_realtime_simulation_result(
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
starttime = time.time()
RealtimeRepository.store_realtime_simulation_result_sync(
conn, node_result_list, link_result_list, result_start_time
)
endtime = time.time()
logger.info(f"存储实时模拟结果耗时: {endtime - starttime}")
break # 成功
except Exception as e:
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1) # 重试前等待
else:
raise # 达到最大重试次数后抛出异常
@staticmethod
async def store_scheme_simulation(
def store_scheme_simulation(
scheme_type: str,
scheme_name: str,
node_result_list: List[dict],
@@ -30,11 +51,18 @@ class InternalStorage:
result_start_time: str,
num_periods: int = 1,
db_name: str = None,
max_retries: int = 3,
):
"""存储方案模拟结果"""
instance = await get_database_instance(db_name)
async with instance.get_connection() as conn:
await SchemeRepository.store_scheme_simulation_result(
for attempt in range(max_retries):
try:
conn_string = (
timescaledb_info.get_pgconn_string(db_name=db_name)
if db_name
else timescaledb_info.get_pgconn_string()
)
with psycopg.Connection.connect(conn_string) as conn:
SchemeRepository.store_scheme_simulation_result_sync(
conn,
scheme_type,
scheme_name,
@@ -43,3 +71,10 @@ class InternalStorage:
result_start_time,
num_periods,
)
break # 成功
except Exception as e:
logger.error(f"存储尝试 {attempt + 1} 失败: {e}")
if attempt < max_retries - 1:
time.sleep(1) # 重试前等待
else:
raise # 达到最大重试次数后抛出异常

View File

@@ -358,7 +358,7 @@ async def insert_scada_data(
@router.get("/scada")
async def get_scada_data(
async def get_scada_by_id_time_range(
device_id: str,
start_time: datetime,
end_time: datetime,
@@ -370,7 +370,7 @@ async def get_scada_data(
@router.get("/scada/{device_id}/field")
async def get_scada_field(
async def get_scada_field_by_id_time_range(
device_id: str,
start_time: datetime,
end_time: datetime,

View File

@@ -1,6 +1,6 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from psycopg import AsyncConnection, sql
from psycopg import AsyncConnection, Connection, sql
# 定义UTC+8时区
UTC_8 = timezone(timedelta(hours=8))
@@ -36,6 +36,32 @@ class RealtimeRepository:
)
)
@staticmethod
def insert_links_batch_sync(conn: Connection, data: List[dict]):
"""Batch insert for realtime.link_simulation using COPY for performance (sync version)."""
if not data:
return
with conn.cursor() as cur:
with cur.copy(
"COPY realtime.link_simulation (time, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_link_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, link_id: str
@@ -180,6 +206,27 @@ class RealtimeRepository:
)
)
@staticmethod
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
if not data:
return
with conn.cursor() as cur:
with cur.copy(
"COPY realtime.node_simulation (time, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_node_by_time_range(
conn: AsyncConnection, start_time: datetime, end_time: datetime, node_id: str
@@ -309,32 +356,36 @@ class RealtimeRepository:
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
node_data.append(
{
"time": simulation_time,
"id": node_result.get("id"),
"actual_demand": node_result.get("actual_demand"),
"total_head": node_result.get("total_head"),
"pressure": node_result.get("pressure"),
"quality": node_result.get("quality"),
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
data = link_result.get("result", [])[0]
link_data.append(
{
"time": simulation_time,
"id": link_result.get("id"),
"flow": link_result.get("flow"),
"friction": link_result.get("friction"),
"headloss": link_result.get("headloss"),
"quality": link_result.get("quality"),
"reaction": link_result.get("reaction"),
"setting": link_result.get("setting"),
"status": link_result.get("status"),
"velocity": link_result.get("velocity"),
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
@@ -345,6 +396,84 @@ class RealtimeRepository:
if link_data:
await RealtimeRepository.insert_links_batch(conn, link_data)
@staticmethod
def store_realtime_simulation_result_sync(
conn: Connection,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
):
"""
Store realtime simulation results to TimescaleDB (sync version).
Args:
conn: Database connection
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
data = node_result.get("result", [])[0] # 实时模拟只有一个周期
node_data.append(
{
"time": simulation_time,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
data = link_result.get("result", [])[0]
link_data.append(
{
"time": simulation_time,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
RealtimeRepository.insert_nodes_batch_sync(conn, node_data)
if link_data:
RealtimeRepository.insert_links_batch_sync(conn, link_data)
@staticmethod
async def query_all_record_by_time_property(
conn: AsyncConnection,

View File

@@ -1,6 +1,6 @@
from typing import List, Any, Dict
from datetime import datetime, timedelta, timezone
from psycopg import AsyncConnection, sql
from psycopg import AsyncConnection, Connection, sql
import globals
# 定义UTC+8时区
@@ -39,6 +39,34 @@ class SchemeRepository:
)
)
@staticmethod
def insert_links_batch_sync(conn: Connection, data: List[dict]):
"""Batch insert for scheme.link_simulation using COPY for performance (sync version)."""
if not data:
return
with conn.cursor() as cur:
with cur.copy(
"COPY scheme.link_simulation (time, scheme_type, scheme_name, id, flow, friction, headloss, quality, reaction, setting, status, velocity) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("flow"),
item.get("friction"),
item.get("headloss"),
item.get("quality"),
item.get("reaction"),
item.get("setting"),
item.get("status"),
item.get("velocity"),
)
)
@staticmethod
async def get_link_by_scheme_and_time_range(
conn: AsyncConnection,
@@ -206,6 +234,29 @@ class SchemeRepository:
)
)
@staticmethod
def insert_nodes_batch_sync(conn: Connection, data: List[dict]):
if not data:
return
with conn.cursor() as cur:
with cur.copy(
"COPY scheme.node_simulation (time, scheme_type, scheme_name, id, actual_demand, total_head, pressure, quality) FROM STDIN"
) as copy:
for item in data:
copy.write_row(
(
item["time"],
item["scheme_type"],
item["scheme_name"],
item["id"],
item.get("actual_demand"),
item.get("total_head"),
item.get("pressure"),
item.get("quality"),
)
)
@staticmethod
async def get_node_by_scheme_and_time_range(
conn: AsyncConnection,
@@ -421,6 +472,104 @@ class SchemeRepository:
if link_data:
await SchemeRepository.insert_links_batch(conn, link_data)
@staticmethod
def store_scheme_simulation_result_sync(
conn: Connection,
scheme_type: str,
scheme_name: str,
node_result_list: List[Dict[str, any]],
link_result_list: List[Dict[str, any]],
result_start_time: str,
num_periods: int = 1,
):
"""
Store scheme simulation results to TimescaleDB (sync version).
Args:
conn: Database connection
scheme_type: Scheme type
scheme_name: Scheme name
node_result_list: List of node simulation results
link_result_list: List of link simulation results
result_start_time: Start time for the results (ISO format string)
"""
# Convert result_start_time string to datetime if needed
if isinstance(result_start_time, str):
# 如果是ISO格式字符串解析并转换为UTC+8
if result_start_time.endswith("Z"):
# UTC时间转换为UTC+8
utc_time = datetime.fromisoformat(
result_start_time.replace("Z", "+00:00")
)
simulation_time = utc_time.astimezone(UTC_8)
else:
# 假设已经是UTC+8时间
simulation_time = datetime.fromisoformat(result_start_time)
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
else:
simulation_time = result_start_time
if simulation_time.tzinfo is None:
simulation_time = simulation_time.replace(tzinfo=UTC_8)
timestep_parts = globals.hydraulic_timestep.split(":")
timestep = timedelta(
hours=int(timestep_parts[0]),
minutes=int(timestep_parts[1]),
seconds=int(timestep_parts[2]),
)
# Prepare node data for batch insert
node_data = []
for node_result in node_result_list:
node_id = node_result.get("node")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = node_result.get("result", [])[period_index]
node_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": node_id,
"actual_demand": data.get("demand"),
"total_head": data.get("head"),
"pressure": data.get("pressure"),
"quality": data.get("quality"),
}
)
# Prepare link data for batch insert
link_data = []
for link_result in link_result_list:
link_id = link_result.get("link")
for period_index in range(num_periods):
current_time = simulation_time + (timestep * period_index)
data = link_result.get("result", [])[period_index]
link_data.append(
{
"time": current_time,
"scheme_type": scheme_type,
"scheme_name": scheme_name,
"id": link_id,
"flow": data.get("flow"),
"friction": data.get("friction"),
"headloss": data.get("headloss"),
"quality": data.get("quality"),
"reaction": data.get("reaction"),
"setting": data.get("setting"),
"status": data.get("status"),
"velocity": data.get("velocity"),
}
)
# Insert data using batch methods
if node_data:
SchemeRepository.insert_nodes_batch_sync(conn, node_data)
if link_data:
SchemeRepository.insert_links_batch_sync(conn, link_data)
@staticmethod
async def query_all_record_by_scheme_time_property(
conn: AsyncConnection,