from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime
import json
import logging

from app.infra.db.postgresql.database import Database
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse

logger = logging.getLogger(__name__)


class AuditRepository:
    """Data-access layer for audit logs stored in the ``audit_logs`` table."""

    # Column list shared by INSERT ... RETURNING and SELECT, so every method
    # builds AuditLogResponse from exactly the same set of keys.
    _AUDIT_COLUMNS = (
        "id, user_id, username, action, resource_type, resource_id, "
        "ip_address, user_agent, request_method, request_path, request_data, "
        "response_status, error_message, timestamp"
    )

    def __init__(self, db: Database):
        """Keep the database handle used to open connections.

        Args:
            db: Project ``Database`` object; this class only uses its
                ``get_connection()`` async context manager.
        """
        self.db = db

    @staticmethod
    def _build_filters(
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build the WHERE clause and bind parameters shared by the query methods.

        Only the fixed comparison fragments below are interpolated into the SQL
        text; every caller-supplied value is passed as a named bind parameter,
        so the dynamic query stays injection-safe.

        Returns:
            ``(where_clause, params)``; ``where_clause`` is ``""`` when no
            filter is active and ``params`` then is empty.
        """
        # (SQL fragment, parameter name, value, whether the filter is active).
        # user_id is compared against None so that 0 still filters; the other
        # filters are skipped when falsy (None or an empty string).
        candidates = (
            ("user_id = %(user_id)s", "user_id", user_id, user_id is not None),
            ("username = %(username)s", "username", username, bool(username)),
            ("action = %(action)s", "action", action, bool(action)),
            (
                "resource_type = %(resource_type)s",
                "resource_type",
                resource_type,
                bool(resource_type),
            ),
            ("timestamp >= %(start_time)s", "start_time", start_time, bool(start_time)),
            ("timestamp <= %(end_time)s", "end_time", end_time, bool(end_time)),
        )
        conditions: List[str] = []
        params: Dict[str, Any] = {}
        for fragment, name, value, active in candidates:
            if active:
                conditions.append(fragment)
                params[name] = value
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where_clause, params

    async def create_log(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: str = "",
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        request_method: Optional[str] = None,
        request_path: Optional[str] = None,
        request_data: Optional[dict] = None,
        response_status: Optional[int] = None,
        error_message: Optional[str] = None
    ) -> Optional[AuditLogResponse]:
        """Insert one audit log row and return the stored record.

        Args:
            user_id, username, action, resource_type, resource_id, ip_address,
            user_agent, request_method, request_path, response_status,
            error_message: Values written to the ``audit_logs`` column of the
                same name (field meanings follow ``AuditLogCreate``).
            request_data: Request payload, serialized with ``json.dumps``
                before insertion. ``None`` is stored as NULL; an empty dict is
                stored as ``"{}"``.

        Returns:
            The inserted row (via ``RETURNING``) as an ``AuditLogResponse``,
            or ``None`` if the statement returned no row.

        Raises:
            Any exception from JSON serialization or the database; it is
            logged with a traceback and re-raised.
        """
        query = f"""
            INSERT INTO audit_logs (
                user_id, username, action, resource_type, resource_id,
                ip_address, user_agent, request_method, request_path,
                request_data, response_status, error_message
            ) VALUES (
                %(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
                %(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
                %(request_data)s, %(response_status)s, %(error_message)s
            )
            RETURNING {self._AUDIT_COLUMNS}
        """
        try:
            # Explicit None check: an empty dict is a real payload and must not
            # be collapsed to NULL, which a plain truthiness test would do.
            serialized_request = (
                json.dumps(request_data) if request_data is not None else None
            )
            params = {
                'user_id': user_id,
                'username': username,
                'action': action,
                'resource_type': resource_type,
                'resource_id': resource_id,
                'ip_address': ip_address,
                'user_agent': user_agent,
                'request_method': request_method,
                'request_path': request_path,
                'request_data': serialized_request,
                'response_status': response_status,
                'error_message': error_message,
            }
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    row = await cur.fetchone()
                    if row:
                        # Assumes the cursor yields mapping rows (e.g. a dict
                        # row factory) — TODO confirm in Database.get_connection.
                        return AuditLogResponse(**row)
        except Exception as e:
            logger.exception("Error creating audit log: %s", e)
            raise
        return None

    async def get_logs(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[AuditLogResponse]:
        """Query audit logs, newest first, with optional filters and paging.

        Args:
            user_id: Filter on ``user_id`` (applied whenever not ``None``).
            username: Filter on ``username`` (skipped when empty).
            action: Filter on ``action`` (skipped when empty).
            resource_type: Filter on ``resource_type`` (skipped when empty).
            start_time: Inclusive lower bound on ``timestamp``.
            end_time: Inclusive upper bound on ``timestamp``.
            skip: Number of rows to skip (SQL ``OFFSET``).
            limit: Maximum number of rows to return (SQL ``LIMIT``).

        Returns:
            Matching rows ordered by ``timestamp`` descending.

        Raises:
            Any database exception; it is logged with a traceback and re-raised.
        """
        where_clause, params = self._build_filters(
            user_id=user_id,
            username=username,
            action=action,
            resource_type=resource_type,
            start_time=start_time,
            end_time=end_time,
        )
        params['skip'] = skip
        params['limit'] = limit

        query = f"""
            SELECT {self._AUDIT_COLUMNS}
            FROM audit_logs
            {where_clause}
            ORDER BY timestamp DESC
            LIMIT %(limit)s OFFSET %(skip)s
        """
        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    rows = await cur.fetchall()
                    # Assumes mapping rows, as in create_log — TODO confirm.
                    return [AuditLogResponse(**row) for row in rows]
        except Exception as e:
            logger.exception("Error querying audit logs: %s", e)
            raise

    async def get_log_count(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None
    ) -> int:
        """Count audit logs matching the same filters as ``get_logs``.

        Args:
            Same filter arguments as ``get_logs`` (without paging).

        Returns:
            Number of matching rows; ``0`` if the query fails.
        """
        where_clause, params = self._build_filters(
            user_id=user_id,
            username=username,
            action=action,
            resource_type=resource_type,
            start_time=start_time,
            end_time=end_time,
        )
        query = f"""
            SELECT COUNT(*) as count
            FROM audit_logs
            {where_clause}
        """
        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    result = await cur.fetchone()
                    return result['count'] if result else 0
        except Exception as e:
            # NOTE(review): unlike create_log/get_logs this is best-effort and
            # reports 0 on failure instead of raising; kept for caller
            # compatibility, but the traceback is now logged.
            logger.exception("Error counting audit logs: %s", e)
            return 0