diff --git a/app/api/v1/endpoints/project.py b/app/api/v1/endpoints/project.py index b503181..84d2c87 100644 --- a/app/api/v1/endpoints/project.py +++ b/app/api/v1/endpoints/project.py @@ -4,6 +4,8 @@ from fastapi.responses import PlainTextResponse from typing import Any, Dict import app.services.project_info as project_info from app.native.api import ChangeSet +from app.infra.db.postgresql.database import get_database_instance as get_pg_db +from app.infra.db.timescaledb.database import get_database_instance as get_ts_db from app.services.tjnetwork import ( list_project, have_project, @@ -62,6 +64,28 @@ async def is_project_open_endpoint(network: str): @router.post("/openproject/") async def open_project_endpoint(network: str): open_project(network) + + # 尝试连接指定数据库 + try: + # 初始化 PostgreSQL 连接池 + pg_instance = await get_pg_db(network) + async with pg_instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + + # 初始化 TimescaleDB 连接池 + ts_instance = await get_ts_db(network) + async with ts_instance.get_connection() as conn: + async with conn.cursor() as cur: + await cur.execute("SELECT 1") + + except Exception as e: + # 记录错误但不阻断项目打开,或者根据需求决定是否阻断 + # 这里选择打印错误,因为 open_project 原本只负责原生部分 + print(f"Failed to connect to databases for {network}: {str(e)}") + # 如果数据库连接是必须的,可以抛出异常: + # raise HTTPException(status_code=500, detail=f"Database connection failed: {str(e)}") + return network @router.post("/closeproject/") diff --git a/app/infra/db/postgresql/database.py b/app/infra/db/postgresql/database.py index 1fb0baa..012ec28 100644 --- a/app/infra/db/postgresql/database.py +++ b/app/infra/db/postgresql/database.py @@ -18,7 +18,12 @@ class Database: """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config target_db_name = db_name or self.db_name - conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name) + + # Get connection string, handling default case where target_db_name might be None + if target_db_name: + conn_string = postgresql_info.get_pgconn_string(db_name=target_db_name) + else: + conn_string = postgresql_info.get_pgconn_string() try: self.pool = psycopg_pool.AsyncConnectionPool( diff --git a/app/infra/db/postgresql/router.py b/app/infra/db/postgresql/router.py index f8d8ccd..6a03d24 100644 --- a/app/infra/db/postgresql/router.py +++ b/app/infra/db/postgresql/router.py @@ -1,7 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException, Query, Body +from fastapi import APIRouter, Depends, HTTPException, Query from typing import Optional from psycopg import AsyncConnection -from pydantic import BaseModel from .database import get_database_instance from .scada_info import ScadaRepository @@ -10,10 +9,6 @@ from .scheme import SchemeRepository router = APIRouter() -class DatabaseConfig(BaseModel): - db_name: str - - # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -26,26 +21,6 @@ async def get_database_connection( yield conn -@router.post("/postgres/open-database") -async def open_database(config: DatabaseConfig): - """ - 尝试连接指定数据库,如果成功则返回成功消息 - """ - try: - instance = await get_database_instance(config.db_name) - # 尝试获取连接并执行简单查询以验证连接 - async with instance.get_connection() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT 1") - await cur.fetchone() - - return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" - ) - - @router.get("/scada-info") async def get_scada_info_with_connection( conn: AsyncConnection = Depends(get_database_connection), diff --git a/app/infra/db/timescaledb/database.py b/app/infra/db/timescaledb/database.py index e16b4c7..12f33e2 100644 --- a/app/infra/db/timescaledb/database.py +++ b/app/infra/db/timescaledb/database.py @@ -18,7 +18,13 @@ class Database: """Initialize the connection pool.""" # Use provided db_name, or the one from constructor, or default from config target_db_name = db_name or self.db_name - conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name) + + # Get connection string, handling default case where target_db_name might be None + if target_db_name: + conn_string = timescaledb_info.get_pgconn_string(db_name=target_db_name) + else: + conn_string = timescaledb_info.get_pgconn_string() + try: self.pool = psycopg_pool.AsyncConnectionPool( conninfo=conn_string, @@ -47,7 +53,9 @@ class Database: def get_pgconn_string(self, db_name=None): """Get the TimescaleDB connection string.""" target_db_name = db_name or self.db_name - return timescaledb_info.get_pgconn_string(db_name=target_db_name) + if target_db_name: + return timescaledb_info.get_pgconn_string(db_name=target_db_name) + return timescaledb_info.get_pgconn_string() @asynccontextmanager async def get_connection(self) -> AsyncGenerator: diff --git a/app/infra/db/timescaledb/router.py b/app/infra/db/timescaledb/router.py index d988780..6ae2b26 100644 --- a/app/infra/db/timescaledb/router.py +++ b/app/infra/db/timescaledb/router.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, Depends, HTTPException, Query from typing import List, Optional from datetime import datetime from psycopg import AsyncConnection -from pydantic import BaseModel from .database import get_database_instance from .schemas.realtime import RealtimeRepository @@ -16,10 +15,6 @@ from app.infra.db.postgresql.database import ( router = APIRouter() -class DatabaseConfig(BaseModel): - db_name: str - - # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( @@ -44,26 +39,6 @@ async def get_postgres_connection( yield conn -@router.post("/timescaledb/open-database") -async def open_database(config: DatabaseConfig): - """ - 尝试连接指定数据库,如果成功则返回成功消息 - """ - try: - instance = await get_database_instance(config.db_name) - # 尝试获取连接并执行简单查询以验证连接 - async with instance.get_connection() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT 1") - await cur.fetchone() - - return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" - ) - - # --- Realtime Endpoints ---