# File: TJWaterServerBinary/app/infra/db/dynamic_manager.py
# (Copied from a repository file viewer: 209 lines, 6.8 KiB, Python.)

import asyncio
import logging
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict
from uuid import UUID
from psycopg_pool import AsyncConnectionPool
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from app.core.config import settings
# Module logger; records are emitted under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass(eq=True, frozen=True)
class PgEngineEntry:
    """Cached SQLAlchemy async engine paired with its session factory."""

    # Engine that owns the connection pool; disposed on eviction / shutdown.
    engine: AsyncEngine
    # Session factory bound to ``engine``; this is what callers receive.
    sessionmaker: async_sessionmaker[AsyncSession]
@dataclass(eq=True, frozen=True)
class CacheKey:
    """Identity of one cached connection resource.

    Frozen with generated equality, so instances are hashable and can be
    used directly as dictionary keys.
    """

    project_id: UUID  # project the resource belongs to
    db_role: str  # caller-supplied role label, compared as a plain string
class ProjectConnectionManager:
    """Cache database engines and connection pools per (project, role).

    Three independent LRU caches are kept, each keyed by ``CacheKey`` and
    guarded by its own ``asyncio.Lock``:

    * ``_pg_cache``     -- SQLAlchemy async engine + sessionmaker (PostgreSQL),
      bounded by ``settings.PROJECT_PG_CACHE_SIZE``;
    * ``_ts_cache``     -- psycopg ``AsyncConnectionPool`` (TimescaleDB),
      bounded by ``settings.PROJECT_TS_CACHE_SIZE``;
    * ``_pg_raw_cache`` -- psycopg ``AsyncConnectionPool`` (PostgreSQL),
      also bounded by ``settings.PROJECT_PG_CACHE_SIZE``.

    A cache hit marks the entry as most recently used. A miss creates the
    resource, inserts it, then evicts least-recently-used entries until the
    cache is back within its limit (never below one entry).

    NOTE(review): eviction closes a pool / disposes an engine even if another
    coroutine obtained it earlier and is still using it; nothing here tracks
    usage. The cache key also ignores ``connection_url``, so a changed URL
    for the same project/role keeps returning the old resource until it is
    evicted. Cleanup runs while the cache lock is held, so a slow close
    delays other lookups on the same cache.
    """

    def __init__(self) -> None:
        # OrderedDict rather than dict: the LRU bookkeeping relies on
        # move_to_end() and popitem(last=False).
        self._pg_cache: OrderedDict[CacheKey, PgEngineEntry] = OrderedDict()
        self._ts_cache: OrderedDict[CacheKey, AsyncConnectionPool] = OrderedDict()
        self._pg_raw_cache: OrderedDict[CacheKey, AsyncConnectionPool] = (
            OrderedDict()
        )
        self._pg_lock = asyncio.Lock()
        self._ts_lock = asyncio.Lock()
        self._pg_raw_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _clamp_pool_sizes(
        pool_min_size: int, pool_max_size: int
    ) -> "tuple[int, int]":
        """Return ``(min_size, max_size)`` with ``1 <= min_size <= max_size``."""
        min_size = max(1, pool_min_size)
        return min_size, max(min_size, pool_max_size)

    @staticmethod
    def _cache_limit(configured_size: int) -> int:
        """Return the effective capacity for a configured cache size.

        Clamped to at least 1: with a limit of 0 the entry that was just
        inserted would itself be evicted and closed before being returned.
        """
        return max(1, configured_size)

    def _normalize_pg_url(self, url: str) -> str:
        """Return *url* rewritten for SQLAlchemy's async psycopg driver.

        A bare ``postgresql://`` (or legacy ``postgres://``) scheme becomes
        ``postgresql+psycopg://``; any other driver name is left unchanged.
        """
        parsed = make_url(url)
        if parsed.drivername in ("postgresql", "postgres"):
            parsed = parsed.set(drivername="postgresql+psycopg")
        # str(URL) masks the password as "***" in SQLAlchemy 2.x, which would
        # make every engine authenticate with the wrong password.
        return parsed.render_as_string(hide_password=False)

    @staticmethod
    async def _dispose_engine(
        key: CacheKey, entry: PgEngineEntry, action: str
    ) -> None:
        """Dispose *entry*'s engine, logging (not raising) any failure.

        *action* is the past-tense verb used in the success log line.
        """
        try:
            await entry.engine.dispose()
        except Exception:
            # Best-effort cleanup: one broken engine must not abort eviction
            # or shutdown of the others.
            logger.exception(
                "Failed to dispose PostgreSQL engine for project %s (%s)",
                key.project_id,
                key.db_role,
            )
            return
        logger.info(
            "%s PostgreSQL engine for project %s (%s)",
            action,
            key.project_id,
            key.db_role,
        )

    @staticmethod
    async def _close_pool(
        key: CacheKey, pool: AsyncConnectionPool, kind: str, action: str
    ) -> None:
        """Close *pool*, logging (not raising) any failure.

        *kind* names the database in log lines; *action* is the past-tense
        verb used in the success log line.
        """
        try:
            await pool.close()
        except Exception:
            # Best-effort cleanup: keep going with the remaining pools.
            logger.exception(
                "Failed to close %s pool for project %s (%s)",
                kind,
                key.project_id,
                key.db_role,
            )
            return
        logger.info(
            "%s %s pool for project %s (%s)",
            action,
            kind,
            key.project_id,
            key.db_role,
        )

    async def _evict_pools(
        self,
        cache: "OrderedDict[CacheKey, AsyncConnectionPool]",
        configured_size: int,
        kind: str,
    ) -> None:
        """Close least-recently-used pools in *cache* beyond its limit."""
        limit = self._cache_limit(configured_size)
        while len(cache) > limit:
            key, pool = cache.popitem(last=False)
            await self._close_pool(key, pool, kind, "Evicted")

    async def _evict_pg_if_needed(self) -> None:
        """Dispose LRU engines beyond ``settings.PROJECT_PG_CACHE_SIZE``."""
        limit = self._cache_limit(settings.PROJECT_PG_CACHE_SIZE)
        while len(self._pg_cache) > limit:
            key, entry = self._pg_cache.popitem(last=False)
            await self._dispose_engine(key, entry, "Evicted")

    async def _evict_ts_if_needed(self) -> None:
        """Close LRU TimescaleDB pools beyond ``PROJECT_TS_CACHE_SIZE``."""
        await self._evict_pools(
            self._ts_cache, settings.PROJECT_TS_CACHE_SIZE, "TimescaleDB"
        )

    async def _evict_pg_raw_if_needed(self) -> None:
        """Close LRU PostgreSQL pools beyond ``PROJECT_PG_CACHE_SIZE``."""
        await self._evict_pools(
            self._pg_raw_cache, settings.PROJECT_PG_CACHE_SIZE, "PostgreSQL"
        )

    async def _get_or_create_pool(
        self,
        cache: "OrderedDict[CacheKey, AsyncConnectionPool]",
        lock: asyncio.Lock,
        evict,
        kind: str,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return the pool cached under (project_id, db_role), opening one on a miss.

        *evict* is a zero-argument coroutine function that trims *cache*;
        it runs after a new pool has been inserted.
        """
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with lock:
            pool = cache.get(key)
            if pool is not None:
                cache.move_to_end(key)
                return pool
            min_size, max_size = self._clamp_pool_sizes(pool_min_size, pool_max_size)
            pool = AsyncConnectionPool(
                conninfo=connection_url,
                min_size=min_size,
                max_size=max_size,
                open=False,
            )
            await pool.open()
            cache[key] = pool
            await evict()
            logger.info(
                "Created %s pool for project %s (%s)", kind, project_id, db_role
            )
            return pool

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_pg_sessionmaker(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> async_sessionmaker[AsyncSession]:
        """Return a cached session factory for (project_id, db_role).

        On a miss, an async engine is created from *connection_url* (after
        normalization to the ``postgresql+psycopg`` driver) with a pool of
        ``pool_min_size`` connections that may overflow up to
        ``pool_max_size`` (both clamped so ``1 <= min <= max``). On a hit the
        URL and sizes are ignored. Exceptions from URL parsing or engine
        creation propagate, and nothing is cached in that case.
        """
        key = CacheKey(project_id=project_id, db_role=db_role)
        async with self._pg_lock:
            entry = self._pg_cache.get(key)
            if entry is not None:
                self._pg_cache.move_to_end(key)
                return entry.sessionmaker
            min_size, max_size = self._clamp_pool_sizes(pool_min_size, pool_max_size)
            engine = create_async_engine(
                self._normalize_pg_url(connection_url),
                pool_size=min_size,
                # QueuePool may open max_overflow connections beyond
                # pool_size, capping the total at max_size.
                max_overflow=max_size - min_size,
                pool_pre_ping=True,
            )
            sessionmaker = async_sessionmaker(engine, expire_on_commit=False)
            self._pg_cache[key] = PgEngineEntry(
                engine=engine,
                sessionmaker=sessionmaker,
            )
            await self._evict_pg_if_needed()
            logger.info(
                "Created PostgreSQL engine for project %s (%s)", project_id, db_role
            )
            return sessionmaker

    async def get_timescale_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return a cached, opened TimescaleDB pool for (project_id, db_role).

        *connection_url* is passed to psycopg as ``conninfo`` unchanged.
        Pool sizes are clamped so ``1 <= min <= max``. The URL and sizes are
        used only when a new pool is created.
        """
        return await self._get_or_create_pool(
            self._ts_cache,
            self._ts_lock,
            self._evict_ts_if_needed,
            "TimescaleDB",
            project_id,
            db_role,
            connection_url,
            pool_min_size,
            pool_max_size,
        )

    async def get_pg_pool(
        self,
        project_id: UUID,
        db_role: str,
        connection_url: str,
        pool_min_size: int,
        pool_max_size: int,
    ) -> AsyncConnectionPool:
        """Return a cached, opened raw PostgreSQL pool for (project_id, db_role).

        *connection_url* is passed to psycopg as ``conninfo`` unchanged.
        NOTE(review): unlike get_pg_sessionmaker it is not normalized; a
        SQLAlchemy-style ``postgresql+psycopg://`` URL would presumably be
        rejected by libpq, so confirm that callers pass a plain conninfo/URI.
        """
        return await self._get_or_create_pool(
            self._pg_raw_cache,
            self._pg_raw_lock,
            self._evict_pg_raw_if_needed,
            "PostgreSQL",
            project_id,
            db_role,
            connection_url,
            pool_min_size,
            pool_max_size,
        )

    async def close_all(self) -> None:
        """Dispose every cached engine and close every cached pool.

        Each cache is emptied under its own lock. A failure while closing
        one resource is logged, and the remaining resources are still closed.
        """
        async with self._pg_lock:
            engines = list(self._pg_cache.items())
            self._pg_cache.clear()
            for key, entry in engines:
                await self._dispose_engine(key, entry, "Closed")
        async with self._ts_lock:
            ts_pools = list(self._ts_cache.items())
            self._ts_cache.clear()
            for key, pool in ts_pools:
                await self._close_pool(key, pool, "TimescaleDB", "Closed")
        async with self._pg_raw_lock:
            pg_pools = list(self._pg_raw_cache.items())
            self._pg_raw_cache.clear()
            for key, pool in pg_pools:
                await self._close_pool(key, pool, "PostgreSQL", "Closed")
# Module-level shared instance, created once when this module is imported.
project_connection_manager: ProjectConnectionManager = ProjectConnectionManager()