diff --git a/app/infra/db/postgresql/database.py b/app/infra/db/postgresql/database.py index c9b9fb2..1fb0baa 100644 --- a/app/infra/db/postgresql/database.py +++ b/app/infra/db/postgresql/database.py @@ -17,7 +17,9 @@ class Database: def init_pool(self, db_name=None): """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config - conn_string = postgresql_info.get_pgconn_string() + target_db_name = db_name or self.db_name + conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name) + try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -26,7 +28,7 @@ class Database: open=False, # Don't open immediately, wait for startup kwargs={"row_factory": dict_row}, # Return rows as dictionaries ) - logger.info(f"PostgreSQL connection pool initialized for database: default") + logger.info(f"PostgreSQL connection pool initialized for database: {target_db_name or 'default'}") except Exception as e: logger.error(f"Failed to initialize postgresql connection pool: {e}") raise diff --git a/app/infra/db/postgresql/router.py b/app/infra/db/postgresql/router.py index 6a03d24..f8d8ccd 100644 --- a/app/infra/db/postgresql/router.py +++ b/app/infra/db/postgresql/router.py @@ -1,6 +1,7 @@ -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Body from typing import Optional from psycopg import AsyncConnection +from pydantic import BaseModel from .database import get_database_instance from .scada_info import ScadaRepository @@ -9,6 +10,10 @@ from .scheme import SchemeRepository router = APIRouter() +class DatabaseConfig(BaseModel): + db_name: str + + # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -21,6 +26,26 @@ async def get_database_connection( yield conn +@router.post("/postgres/open-database") +async def open_database(config: DatabaseConfig): + """ + 尝试连接指定数据库,如果成功则返回成功消息 + """ + try: + instance = await get_database_instance(config.db_name) + # 尝试获取连接并执行简单查询以验证连接 + async with instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + await cur.fetchone() + + return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" + ) + + @router.get("/scada-info") async def get_scada_info_with_connection( conn: AsyncConnection = Depends(get_database_connection), diff --git a/app/infra/db/timescaledb/database.py b/app/infra/db/timescaledb/database.py index e3726fc..e16b4c7 100644 --- a/app/infra/db/timescaledb/database.py +++ b/app/infra/db/timescaledb/database.py @@ -17,7 +17,8 @@ class Database: def init_pool(self, db_name=None): """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config - conn_string = timescaledb_info.get_pgconn_string() + target_db_name = db_name or self.db_name + conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name) try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -27,7 +28,7 @@ class Database: kwargs={"row_factory": dict_row}, # Return rows as dictionaries ) logger.info( - f"TimescaleDB connection pool initialized for database: default" + f"TimescaleDB connection pool initialized for database: {target_db_name or 'default'}" ) except Exception as e: logger.error(f"Failed to initialize TimescaleDB connection pool: {e}") diff --git a/app/infra/db/timescaledb/router.py b/app/infra/db/timescaledb/router.py index 6ae2b26..d988780 100644 --- a/app/infra/db/timescaledb/router.py +++ b/app/infra/db/timescaledb/router.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional from datetime import datetime from psycopg import AsyncConnection +from pydantic import BaseModel from .database import get_database_instance from .schemas.realtime import RealtimeRepository @@ -15,6 +16,10 @@ from app.infra.db.postgresql.database import ( router = APIRouter() +class DatabaseConfig(BaseModel): + db_name: str + + # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -39,6 +44,26 @@ async def get_postgres_connection( yield conn +@router.post("/timescaledb/open-database") +async def open_database(config: DatabaseConfig): + """ + 尝试连接指定数据库,如果成功则返回成功消息 + """ + try: + instance = await get_database_instance(config.db_name) + # 尝试获取连接并执行简单查询以验证连接 + async with instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + await cur.fetchone() + + return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" + ) + + # --- Realtime Endpoints --- diff --git a/app/services/simulation.py b/app/services/simulation.py index 4b7e976..c7eca9a 100644 --- a/app/services/simulation.py +++ b/app/services/simulation.py @@ -1235,7 +1235,7 @@ def run_simulation( starttime = time.time() if simulation_type.upper() == "REALTIME": TimescaleInternalStorage.store_realtime_simulation( - node_result, link_result, modify_pattern_start_time + node_result, link_result, modify_pattern_start_time, db_name=name ) elif simulation_type.upper() == "EXTENDED": TimescaleInternalStorage.store_scheme_simulation( @@ -1245,6 +1245,7 @@ def run_simulation( link_result, modify_pattern_start_time, num_periods_result, + db_name=name, ) endtime = time.time() logging.info("store time: %f", endtime - starttime)