221 lines
7.5 KiB
Python
221 lines
7.5 KiB
Python
from typing import Optional, List
|
|
from datetime import datetime
|
|
import json
|
|
from app.infra.db.postgresql.database import Database
|
|
from app.domain.schemas.audit import AuditLogCreate, AuditLogResponse
|
|
import logging
|
|
|
|
# Module-scoped logger, named after this module, used by the repository to report failures.
logger = logging.getLogger(name=__name__)
|
|
|
|
class AuditRepository:
    """Data-access layer for the ``audit_logs`` table."""

    # Column list shared by ``INSERT ... RETURNING`` and the paged ``SELECT`` so the
    # row shapes handed to ``AuditLogResponse`` cannot drift apart.
    _COLUMNS = (
        "id, user_id, username, action, resource_type, resource_id, "
        "ip_address, user_agent, request_method, request_path, "
        "request_data, response_status, error_message, timestamp"
    )

    def __init__(self, db: Database):
        """Store the database handle used to open connections.

        Args:
            db: Object whose ``get_connection()`` is used as an async context
                manager yielding a connection with an async ``cursor()``.
        """
        self.db = db

    @staticmethod
    def _build_filters(
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
    ) -> tuple:
        """Build the WHERE clause and bind parameters shared by list and count queries.

        Only fixed column predicates are placed in the SQL text; every filter value
        is passed as a bound ``%(name)s`` parameter.

        ``user_id`` is applied whenever it is not ``None`` (so ``0`` still filters);
        the string filters are applied only when non-empty.

        Returns:
            ``(where_clause, params)``: ``where_clause`` is ``""`` when no filter
            applies, otherwise ``"WHERE <cond> AND <cond> ..."``; ``params`` maps
            placeholder names to values.
        """
        conditions = []
        params = {}

        if user_id is not None:
            conditions.append("user_id = %(user_id)s")
            params['user_id'] = user_id

        if username:
            conditions.append("username = %(username)s")
            params['username'] = username

        if action:
            conditions.append("action = %(action)s")
            params['action'] = action

        if resource_type:
            conditions.append("resource_type = %(resource_type)s")
            params['resource_type'] = resource_type

        # datetime instances are always truthy, so `is not None` matches the
        # previous truthiness checks while stating the intent explicitly.
        if start_time is not None:
            conditions.append("timestamp >= %(start_time)s")
            params['start_time'] = start_time

        if end_time is not None:
            conditions.append("timestamp <= %(end_time)s")
            params['end_time'] = end_time

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where_clause, params

    async def create_log(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: str = "",
        resource_type: Optional[str] = None,
        resource_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        request_method: Optional[str] = None,
        request_path: Optional[str] = None,
        request_data: Optional[dict] = None,
        response_status: Optional[int] = None,
        error_message: Optional[str] = None
    ) -> Optional[AuditLogResponse]:
        """Insert one audit log row and return it as stored.

        Args:
            The parameters mirror the fields of ``AuditLogCreate``.
            ``request_data`` is serialized with ``json.dumps``; ``None`` is stored
            as NULL, while an empty dict is stored as ``"{}"``.

        Returns:
            The inserted row as an ``AuditLogResponse``, or ``None`` if the
            INSERT returned no row.

        Raises:
            Any exception raised while serializing, executing, or building the
            response; it is logged with its traceback and re-raised.
        """
        query = f"""
            INSERT INTO audit_logs (
                user_id, username, action, resource_type, resource_id,
                ip_address, user_agent, request_method, request_path,
                request_data, response_status, error_message
            )
            VALUES (
                %(user_id)s, %(username)s, %(action)s, %(resource_type)s, %(resource_id)s,
                %(ip_address)s, %(user_agent)s, %(request_method)s, %(request_path)s,
                %(request_data)s, %(response_status)s, %(error_message)s
            )
            RETURNING {self._COLUMNS}
        """

        try:
            params = {
                'user_id': user_id,
                'username': username,
                'action': action,
                'resource_type': resource_type,
                'resource_id': resource_id,
                'ip_address': ip_address,
                'user_agent': user_agent,
                'request_method': request_method,
                'request_path': request_path,
                # `is not None` so an explicitly empty payload ({}) is recorded
                # instead of collapsing to NULL like a missing one.
                'request_data': json.dumps(request_data) if request_data is not None else None,
                'response_status': response_status,
                'error_message': error_message,
            }
            # NOTE(review): whether this INSERT is committed depends on
            # Database.get_connection (autocommit or commit-on-exit) — confirm.
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    row = await cur.fetchone()
                    # Rows are unpacked as mappings, so the cursor is assumed to
                    # use a dict-style row factory.
                    if row:
                        return AuditLogResponse(**row)
        except Exception as e:
            logger.exception("Error creating audit log: %s", e)
            raise

        return None

    async def get_logs(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[AuditLogResponse]:
        """Return one page of audit logs, newest first.

        Args:
            user_id: Filter by user ID (applied whenever not ``None``).
            username: Filter by exact username (ignored when empty).
            action: Filter by action type (ignored when empty).
            resource_type: Filter by resource type (ignored when empty).
            start_time: Inclusive lower bound on ``timestamp``.
            end_time: Inclusive upper bound on ``timestamp``.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).

        Returns:
            The matching rows as ``AuditLogResponse`` objects.

        Raises:
            Any database error; it is logged with its traceback and re-raised.
        """
        where_clause, params = self._build_filters(
            user_id=user_id,
            username=username,
            action=action,
            resource_type=resource_type,
            start_time=start_time,
            end_time=end_time,
        )
        params['skip'] = skip
        params['limit'] = limit

        # `id DESC` breaks ties between equal timestamps so OFFSET pagination is
        # deterministic (no rows repeated or skipped across pages).
        query = f"""
            SELECT {self._COLUMNS}
            FROM audit_logs
            {where_clause}
            ORDER BY timestamp DESC, id DESC
            LIMIT %(limit)s OFFSET %(skip)s
        """

        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    rows = await cur.fetchall()
                    return [AuditLogResponse(**row) for row in rows]
        except Exception as e:
            logger.exception("Error querying audit logs: %s", e)
            raise

    async def get_log_count(
        self,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None
    ) -> int:
        """Count audit logs matching the same filters as ``get_logs``.

        Args:
            Same filter parameters as ``get_logs`` (without paging).

        Returns:
            The number of matching rows, or ``0`` if the query fails.
        """
        where_clause, params = self._build_filters(
            user_id=user_id,
            username=username,
            action=action,
            resource_type=resource_type,
            start_time=start_time,
            end_time=end_time,
        )

        query = f"""
            SELECT COUNT(*) as count
            FROM audit_logs
            {where_clause}
        """

        try:
            async with self.db.get_connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, params)
                    result = await cur.fetchone()
                    return result['count'] if result else 0
        except Exception as e:
            # Existing behaviour preserved: failures are reported as 0 instead of
            # propagating. NOTE(review): this differs from get_logs, which
            # re-raises — confirm callers rely on the best-effort count.
            logger.exception("Error counting audit logs: %s", e)
            return 0
|