152 lines
4.8 KiB
Python
152 lines
4.8 KiB
Python
from fastmcp import FastMCP, Context
|
|
from typing import Optional, Dict, Any
|
|
|
|
from ..postgresql.database import get_database_instance
|
|
from ..postgresql.scada_info import ScadaRepository
|
|
from ..postgresql.scheme import SchemeRepository
|
|
|
|
# 创建 MCP 服务器实例
|
|
mcp = FastMCP("TJWater PostgreSQL Service", description="访问水务系统 PostgreSQL 数据库操作")
|
|
|
|
|
|
# 数据库连接辅助函数
|
|
async def get_database_connection(db_name: Optional[str] = None, ctx: Context = None):
|
|
"""获取数据库连接,支持通过参数指定数据库名称"""
|
|
if ctx:
|
|
await ctx.info(f"连接到数据库: {db_name or '默认数据库'}")
|
|
|
|
instance = await get_database_instance(db_name)
|
|
async with instance.get_connection() as conn:
|
|
yield conn
|
|
|
|
|
|
@mcp.tool
|
|
async def get_scada_info(db_name: Optional[str] = None, ctx: Context = None) -> Dict[str, Any]:
|
|
"""
|
|
查询所有 SCADA 信息
|
|
|
|
Args:
|
|
db_name: 可选的数据库名称,为空时使用默认数据库
|
|
ctx: MCP 上下文,用于日志记录
|
|
"""
|
|
try:
|
|
if ctx:
|
|
await ctx.info("查询 SCADA 信息...")
|
|
|
|
async for conn in get_database_connection(db_name, ctx):
|
|
scada_data = await ScadaRepository.get_scadas(conn)
|
|
|
|
if ctx:
|
|
await ctx.info(f"检索到 {len(scada_data)} 条 SCADA 记录")
|
|
|
|
return {"success": True, "data": scada_data, "count": len(scada_data)}
|
|
except Exception as e:
|
|
error_msg = f"查询 SCADA 信息时发生错误: {str(e)}"
|
|
if ctx:
|
|
await ctx.error(error_msg)
|
|
return {"success": False, "error": error_msg}
|
|
|
|
|
|
@mcp.tool
|
|
async def get_scheme_list(db_name: Optional[str] = None, ctx: Context = None) -> Dict[str, Any]:
|
|
"""
|
|
查询所有方案信息
|
|
|
|
Args:
|
|
db_name: 可选的数据库名称,为空时使用默认数据库
|
|
ctx: MCP 上下文,用于日志记录
|
|
"""
|
|
try:
|
|
if ctx:
|
|
await ctx.info("查询方案信息...")
|
|
|
|
async for conn in get_database_connection(db_name, ctx):
|
|
scheme_data = await SchemeRepository.get_schemes(conn)
|
|
|
|
if ctx:
|
|
await ctx.info(f"检索到 {len(scheme_data)} 条方案记录")
|
|
|
|
return {"success": True, "data": scheme_data, "count": len(scheme_data)}
|
|
except Exception as e:
|
|
error_msg = f"查询方案信息时发生错误: {str(e)}"
|
|
if ctx:
|
|
await ctx.error(error_msg)
|
|
return {"success": False, "error": error_msg}
|
|
|
|
|
|
@mcp.tool
|
|
async def get_burst_locate_results(db_name: Optional[str] = None, ctx: Context = None) -> Dict[str, Any]:
|
|
"""
|
|
查询所有爆管定位结果
|
|
|
|
Args:
|
|
db_name: 可选的数据库名称,为空时使用默认数据库
|
|
ctx: MCP 上下文,用于日志记录
|
|
"""
|
|
try:
|
|
if ctx:
|
|
await ctx.info("查询爆管定位结果...")
|
|
|
|
async for conn in get_database_connection(db_name, ctx):
|
|
burst_data = await SchemeRepository.get_burst_locate_results(conn)
|
|
|
|
if ctx:
|
|
await ctx.info(f"检索到 {len(burst_data)} 条爆管记录")
|
|
|
|
return {"success": True, "data": burst_data, "count": len(burst_data)}
|
|
except Exception as e:
|
|
error_msg = f"查询爆管定位结果时发生错误: {str(e)}"
|
|
if ctx:
|
|
await ctx.error(error_msg)
|
|
return {"success": False, "error": error_msg}
|
|
|
|
|
|
@mcp.tool
|
|
async def get_burst_locate_result_by_incident(
|
|
burst_incident: str,
|
|
db_name: Optional[str] = None,
|
|
ctx: Context = None
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
根据 burst_incident 查询爆管定位结果
|
|
|
|
Args:
|
|
burst_incident: 爆管事件标识符
|
|
db_name: 可选的数据库名称,为空时使用默认数据库
|
|
ctx: MCP 上下文,用于日志记录
|
|
"""
|
|
try:
|
|
if ctx:
|
|
await ctx.info(f"查询爆管事件 {burst_incident} 的结果...")
|
|
|
|
async for conn in get_database_connection(db_name, ctx):
|
|
result = await SchemeRepository.get_burst_locate_result_by_incident(
|
|
conn, burst_incident
|
|
)
|
|
|
|
if ctx:
|
|
await ctx.info("检索到爆管事件数据")
|
|
|
|
return result
|
|
except Exception as e:
|
|
error_msg = f"根据 burst_incident 查询爆管定位结果时发生错误: {str(e)}"
|
|
if ctx:
|
|
await ctx.error(error_msg)
|
|
return {"success": False, "error": error_msg}
|
|
|
|
|
|
# 添加静态配置资源
|
|
@mcp.resource("config://database/supported_databases")
|
|
def get_supported_databases():
|
|
"""列出支持的数据库配置"""
|
|
return ["default", "backup", "analytics"]
|
|
|
|
|
|
@mcp.resource("config://api/version")
|
|
def get_api_version():
|
|
"""获取 API 版本"""
|
|
return "1.0.0"
|
|
|
|
|
|
if __name__ == "__main__":
|
|
mcp.run() |