From eb45e4aaa5d65a68ba04fd2e378034ea09ae04fd Mon Sep 17 00:00:00 2001 From: Jiang Date: Wed, 11 Feb 2026 10:42:40 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=BB=A3=E7=A0=81=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=A1=B9=E7=9B=AE=E5=88=87=E6=8D=A2=EF=BC=8C?= =?UTF-8?q?=E6=89=93=E5=BC=80=E4=B8=8D=E5=90=8C=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E7=9A=84=E8=BF=9E=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/infra/db/postgresql/database.py | 6 ++++-- app/infra/db/postgresql/router.py | 27 ++++++++++++++++++++++++++- app/infra/db/timescaledb/database.py | 5 +++-- app/infra/db/timescaledb/router.py | 25 +++++++++++++++++++++++++ app/services/simulation.py | 3 ++- 5 files changed, 60 insertions(+), 6 deletions(-) diff --git a/app/infra/db/postgresql/database.py b/app/infra/db/postgresql/database.py index c9b9fb2..1fb0baa 100644 --- a/app/infra/db/postgresql/database.py +++ b/app/infra/db/postgresql/database.py @@ -17,7 +17,9 @@ class Database: def init_pool(self, db_name=None): """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config - conn_string = postgresql_info.get_pgconn_string() + target_db_name = db_name or self.db_name + conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name) + try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -26,7 +28,7 @@ class Database: open=False, # Don't open immediately, wait for startup kwargs={"row_factory": dict_row}, # Return rows as dictionaries ) - logger.info(f"PostgreSQL connection pool initialized for database: default") + logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}") except Exception as e: logger.error(f"Failed to initialize postgresql connection pool: {e}") raise diff --git a/app/infra/db/postgresql/router.py b/app/infra/db/postgresql/router.py index 6a03d24..f8d8ccd 100644 --- a/app/infra/db/postgresql/router.py +++ b/app/infra/db/postgresql/router.py @@ -1,6 +1,7 @@ -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Body from typing import Optional from psycopg import AsyncConnection +from pydantic import BaseModel from .database import get_database_instance from .scada_info import ScadaRepository @@ -9,6 +10,10 @@ from .scheme import SchemeRepository router = APIRouter() +class DatabaseConfig(BaseModel): + db_name: str + + # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -21,6 +26,26 @@ async def get_database_connection( yield conn +@router.post("/postgres/open-database") +async def open_database(config: DatabaseConfig): + """ + 尝试连接指定数据库,如果成功则返回成功消息 + """ + try: + instance = await get_database_instance(config.db_name) + # 尝试获取连接并执行简单查询以验证连接 + async with instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + await cur.fetchone() + + return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" + ) + + @router.get("/scada-info") async def get_scada_info_with_connection( conn: AsyncConnection = Depends(get_database_connection), diff --git a/app/infra/db/timescaledb/database.py b/app/infra/db/timescaledb/database.py index e3726fc..e16b4c7 100644 --- a/app/infra/db/timescaledb/database.py +++ b/app/infra/db/timescaledb/database.py @@ -17,7 +17,8 @@ class Database: def init_pool(self, db_name=None): """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config - conn_string = timescaledb_info.get_pgconn_string() + target_db_name = db_name or self.db_name + conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name) try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -27,7 +28,7 @@ class Database: kwargs={"row_factory": dict_row}, # Return rows as dictionaries ) logger.info( - f"TimescaleDB connection pool initialized for database: default" + f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}" ) except Exception as e: logger.error(f"Failed to initialize TimescaleDB connection pool: {e}") diff --git a/app/infra/db/timescaledb/router.py b/app/infra/db/timescaledb/router.py index 6ae2b26..d988780 100644 --- a/app/infra/db/timescaledb/router.py +++ b/app/infra/db/timescaledb/router.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional from datetime import datetime from psycopg import AsyncConnection +from pydantic import BaseModel from .database import get_database_instance from .schemas.realtime import RealtimeRepository @@ -15,6 +16,10 @@ from app.infra.db.postgresql.database import ( router = APIRouter() +class DatabaseConfig(BaseModel): + db_name: str + + # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -39,6 +44,26 @@ async def get_postgres_connection( yield conn +@router.post("/timescaledb/open-database") +async def open_database(config: DatabaseConfig): + """ + 尝试连接指定数据库,如果成功则返回成功消息 + """ + try: + instance = await get_database_instance(config.db_name) + # 尝试获取连接并执行简单查询以验证连接 + async with instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + await cur.fetchone() + + return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" + ) + + # --- Realtime Endpoints --- diff --git a/app/services/simulation.py b/app/services/simulation.py index 4b7e976..c7eca9a 100644 --- a/app/services/simulation.py +++ b/app/services/simulation.py @@ -1235,7 +1235,7 @@ def run_simulation( starttime = time.time() if simulation_type.upper() == "REALTIME": TimescaleInternalStorage.store_realtime_simulation( - node_result, link_result, modify_pattern_start_time + node_result, link_result, modify_pattern_start_time, db_name=name ) elif simulation_type.upper() == "EXTENDED": TimescaleInternalStorage.store_scheme_simulation( @@ -1245,6 +1245,7 @@ def run_simulation( link_result, modify_pattern_start_time, num_periods_result, + db_name=name, ) endtime = time.time() logging.info("store time: %f", endtime - starttime)