重构数据库连接管理,添加元数据支持
This commit is contained in:
@@ -1,220 +1,112 @@
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
from app.infra.db.postgresql.database import Database
|
||||
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.schemas.audit import AuditLogResponse
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuditRepository:
|
||||
"""审计日志数据访问层"""
|
||||
|
||||
def __init__(self, db: Database):
|
||||
self.db = db
|
||||
|
||||
"""审计日志数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
action: str = "",
|
||||
action: str,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
resource_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_method: Optional[str] = None,
|
||||
request_path: Optional[str] = None,
|
||||
request_data: Optional[dict] = None,
|
||||
response_status: Optional[int] = None,
|
||||
error_message: Optional[str] = None
|
||||
) -> Optional[AuditLogResponse]:
|
||||
"""
|
||||
创建审计日志
|
||||
|
||||
Args:
|
||||
参数说明见 AuditLogCreate
|
||||
|
||||
Returns:
|
||||
创建的审计日志对象
|
||||
"""
|
||||
query = """
|
||||
INSERT INTO audit_logs (
|
||||
user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message
|
||||
)
|
||||
VALUES (
|
||||
%(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
|
||||
%(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
|
||||
%(request_data)s, %(response_status)s, %(error_message)s
|
||||
)
|
||||
RETURNING id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, {
|
||||
'user_id': user_id,
|
||||
'username': username,
|
||||
'action': action,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'ip_address': ip_address,
|
||||
'user_agent': user_agent,
|
||||
'request_method': request_method,
|
||||
'request_path': request_path,
|
||||
'request_data': json.dumps(request_data) if request_data else None,
|
||||
'response_status': response_status,
|
||||
'error_message': error_message
|
||||
})
|
||||
row = await cur.fetchone()
|
||||
if row:
|
||||
return AuditLogResponse(**row)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating audit log: {e}")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
) -> AuditLogResponse:
|
||||
log = models.AuditLog(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action=action,
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
ip_address=ip_address,
|
||||
request_method=request_method,
|
||||
request_path=request_path,
|
||||
request_data=request_data,
|
||||
response_status=response_status,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
self.session.add(log)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(log)
|
||||
return AuditLogResponse.model_validate(log)
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
limit: int = 100,
|
||||
) -> List[AuditLogResponse]:
|
||||
"""
|
||||
查询审计日志
|
||||
|
||||
Args:
|
||||
user_id: 用户ID过滤
|
||||
username: 用户名过滤
|
||||
action: 操作类型过滤
|
||||
resource_type: 资源类型过滤
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
skip: 跳过记录数
|
||||
limit: 限制记录数
|
||||
|
||||
Returns:
|
||||
审计日志列表
|
||||
"""
|
||||
# 构建动态查询
|
||||
conditions = []
|
||||
params = {'skip': skip, 'limit': limit}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT id, user_id, username, action, resource_type, resource_id,
|
||||
ip_address, user_agent, request_method, request_path,
|
||||
request_data, response_status, error_message, timestamp
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT %(limit)s OFFSET %(skip)s
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
rows = await cur.fetchall()
|
||||
return [AuditLogResponse(**row) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying audit logs: {e}")
|
||||
raise
|
||||
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = (
|
||||
select(models.AuditLog)
|
||||
.where(*conditions)
|
||||
.order_by(models.AuditLog.timestamp.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
AuditLogResponse.model_validate(log)
|
||||
for log in result.scalars().all()
|
||||
]
|
||||
|
||||
async def get_log_count(
|
||||
self,
|
||||
user_id: Optional[int] = None,
|
||||
username: Optional[str] = None,
|
||||
user_id: Optional[UUID] = None,
|
||||
project_id: Optional[UUID] = None,
|
||||
action: Optional[str] = None,
|
||||
resource_type: Optional[str] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> int:
|
||||
"""
|
||||
获取审计日志数量
|
||||
|
||||
Args:
|
||||
参数同 get_logs
|
||||
|
||||
Returns:
|
||||
日志总数
|
||||
"""
|
||||
conditions = []
|
||||
params = {}
|
||||
|
||||
if user_id is not None:
|
||||
conditions.append("user_id = %(user_id)s")
|
||||
params['user_id'] = user_id
|
||||
|
||||
if username:
|
||||
conditions.append("username = %(username)s")
|
||||
params['username'] = username
|
||||
|
||||
conditions.append(models.AuditLog.user_id == user_id)
|
||||
if project_id is not None:
|
||||
conditions.append(models.AuditLog.project_id == project_id)
|
||||
if action:
|
||||
conditions.append("action = %(action)s")
|
||||
params['action'] = action
|
||||
|
||||
conditions.append(models.AuditLog.action == action)
|
||||
if resource_type:
|
||||
conditions.append("resource_type = %(resource_type)s")
|
||||
params['resource_type'] = resource_type
|
||||
|
||||
conditions.append(models.AuditLog.resource_type == resource_type)
|
||||
if start_time:
|
||||
conditions.append("timestamp >= %(start_time)s")
|
||||
params['start_time'] = start_time
|
||||
|
||||
conditions.append(models.AuditLog.timestamp >= start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= %(end_time)s")
|
||||
params['end_time'] = end_time
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM audit_logs
|
||||
{where_clause}
|
||||
"""
|
||||
|
||||
try:
|
||||
async with self.db.get_connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
await cur.execute(query, params)
|
||||
result = await cur.fetchone()
|
||||
return result['count'] if result else 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting audit logs: {e}")
|
||||
return 0
|
||||
conditions.append(models.AuditLog.timestamp <= end_time)
|
||||
|
||||
stmt = select(func.count()).select_from(models.AuditLog).where(*conditions)
|
||||
result = await self.session.execute(stmt)
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
164
app/infra/repositories/metadata_repository.py
Normal file
164
app/infra/repositories/metadata_repository.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.encryption import get_encryptor
|
||||
from app.infra.db.metadata import models
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectDbRouting:
|
||||
project_id: UUID
|
||||
db_role: str
|
||||
db_type: str
|
||||
dsn: str
|
||||
pool_min_size: int
|
||||
pool_max_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectGeoServerInfo:
|
||||
project_id: UUID
|
||||
gs_base_url: Optional[str]
|
||||
gs_admin_user: Optional[str]
|
||||
gs_admin_password: Optional[str]
|
||||
gs_datastore_name: str
|
||||
default_extent: Optional[dict]
|
||||
srid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProjectSummary:
|
||||
project_id: UUID
|
||||
name: str
|
||||
code: str
|
||||
description: Optional[str]
|
||||
gs_workspace: str
|
||||
status: str
|
||||
project_role: str
|
||||
|
||||
|
||||
class MetadataRepository:
|
||||
"""元数据访问层(system_hub)"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_user_by_keycloak_id(
|
||||
self, keycloak_id: UUID
|
||||
) -> Optional[models.User]:
|
||||
result = await self.session.execute(
|
||||
select(models.User).where(models.User.keycloak_id == keycloak_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_by_id(self, project_id: UUID) -> Optional[models.Project]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).where(models.Project.id == project_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_membership_role(
|
||||
self, project_id: UUID, user_id: UUID
|
||||
) -> Optional[str]:
|
||||
result = await self.session.execute(
|
||||
select(models.UserProjectMembership.project_role).where(
|
||||
models.UserProjectMembership.project_id == project_id,
|
||||
models.UserProjectMembership.user_id == user_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_project_db_routing(
|
||||
self, project_id: UUID, db_role: str
|
||||
) -> Optional[ProjectDbRouting]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectDatabase).where(
|
||||
models.ProjectDatabase.project_id == project_id,
|
||||
models.ProjectDatabase.db_role == db_role,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
encryptor = get_encryptor()
|
||||
dsn = encryptor.decrypt(record.dsn_encrypted)
|
||||
return ProjectDbRouting(
|
||||
project_id=record.project_id,
|
||||
db_role=record.db_role,
|
||||
db_type=record.db_type,
|
||||
dsn=dsn,
|
||||
pool_min_size=record.pool_min_size,
|
||||
pool_max_size=record.pool_max_size,
|
||||
)
|
||||
|
||||
async def get_geoserver_config(
|
||||
self, project_id: UUID
|
||||
) -> Optional[ProjectGeoServerInfo]:
|
||||
result = await self.session.execute(
|
||||
select(models.ProjectGeoServerConfig).where(
|
||||
models.ProjectGeoServerConfig.project_id == project_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return None
|
||||
encryptor = get_encryptor()
|
||||
password = (
|
||||
encryptor.decrypt(record.gs_admin_password_encrypted)
|
||||
if record.gs_admin_password_encrypted
|
||||
else None
|
||||
)
|
||||
return ProjectGeoServerInfo(
|
||||
project_id=record.project_id,
|
||||
gs_base_url=record.gs_base_url,
|
||||
gs_admin_user=record.gs_admin_user,
|
||||
gs_admin_password=password,
|
||||
gs_datastore_name=record.gs_datastore_name,
|
||||
default_extent=record.default_extent,
|
||||
srid=record.srid,
|
||||
)
|
||||
|
||||
async def list_projects_for_user(self, user_id: UUID) -> List[ProjectSummary]:
|
||||
stmt = (
|
||||
select(models.Project, models.UserProjectMembership.project_role)
|
||||
.join(
|
||||
models.UserProjectMembership,
|
||||
models.UserProjectMembership.project_id == models.Project.id,
|
||||
)
|
||||
.where(models.UserProjectMembership.user_id == user_id)
|
||||
.order_by(models.Project.name)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role=role,
|
||||
)
|
||||
for project, role in result.all()
|
||||
]
|
||||
|
||||
async def list_all_projects(self) -> List[ProjectSummary]:
|
||||
result = await self.session.execute(
|
||||
select(models.Project).order_by(models.Project.name)
|
||||
)
|
||||
return [
|
||||
ProjectSummary(
|
||||
project_id=project.id,
|
||||
name=project.name,
|
||||
code=project.code,
|
||||
description=project.description,
|
||||
gs_workspace=project.gs_workspace,
|
||||
status=project.status,
|
||||
project_role="owner",
|
||||
)
|
||||
for project in result.scalars().all()
|
||||
]
|
||||
Reference in New Issue
Block a user