from fastapi import APIRouter, Depends, HTTPException, Query, Body from typing import Optional from psycopg import AsyncConnection from pydantic import BaseModel from .database import get_database_instance from .scada_info import ScadaRepository from .scheme import SchemeRepository router = APIRouter() class DatabaseConfig(BaseModel): db_name: str # 创建支持数据库选择的连接依赖函数 async def get_database_connection( db_name: Optional[str] = Query( None, description="指定要连接的数据库名称,为空时使用默认数据库" ) ): """获取数据库连接,支持通过查询参数指定数据库名称""" instance = await get_database_instance(db_name) async with instance.get_connection() as conn: yield conn @router.post("/postgres/open-database") async def open_database(config: DatabaseConfig): """ 尝试连接指定数据库,如果成功则返回成功消息 """ try: instance = await get_database_instance(config.db_name) # 尝试获取连接并执行简单查询以验证连接 async with instance.get_connection() as conn: async with conn.cursor() as cur: await cur.execute("SELECT 1") await cur.fetchone() return {"success": True, "message": f"Successfully connected to database: {config.db_name}"} except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to connect to database {config.db_name}: {str(e)}" ) @router.get("/scada-info") async def get_scada_info_with_connection( conn: AsyncConnection = Depends(get_database_connection), ): """ 使用连接池查询所有SCADA信息 """ try: # 使用ScadaRepository查询SCADA信息 scada_data = await ScadaRepository.get_scadas(conn) return {"success": True, "data": scada_data, "count": len(scada_data)} except Exception as e: raise HTTPException( status_code=500, detail=f"查询SCADA信息时发生错误: {str(e)}" ) @router.get("/scheme-list") async def get_scheme_list_with_connection( conn: AsyncConnection = Depends(get_database_connection), ): """ 使用连接池查询所有方案信息 """ try: # 使用SchemeRepository查询方案信息 scheme_data = await SchemeRepository.get_schemes(conn) return {"success": True, "data": scheme_data, "count": len(scheme_data)} except Exception as e: raise HTTPException(status_code=500, detail=f"查询方案信息时发生错误: {str(e)}") @router.get("/burst-locate-result") async def get_burst_locate_result_with_connection( conn: AsyncConnection = Depends(get_database_connection), ): """ 使用连接池查询所有爆管定位结果 """ try: # 使用SchemeRepository查询爆管定位结果 burst_data = await SchemeRepository.get_burst_locate_results(conn) return {"success": True, "data": burst_data, "count": len(burst_data)} except Exception as e: raise HTTPException( status_code=500, detail=f"查询爆管定位结果时发生错误: {str(e)}" ) @router.get("/burst-locate-result/{burst_incident}") async def get_burst_locate_result_by_incident( burst_incident: str, conn: AsyncConnection = Depends(get_database_connection), ): """ 根据 burst_incident 查询爆管定位结果 """ try: # 使用SchemeRepository查询爆管定位结果 return await SchemeRepository.get_burst_locate_result_by_incident( conn, burst_incident ) except Exception as e: raise HTTPException( status_code=500, detail=f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}", )