198 lines
5.2 KiB
Python
198 lines
5.2 KiB
Python
import importlib
|
|
import importlib.util
|
|
import os
|
|
import sys
|
|
import types
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
|
|
# Automatically add the project root to the import path (accommodates the
# project layout). parents[1] is the directory one level above the folder
# that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Prepend (index 0) so this entry is searched before the existing ones.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
|
def run_this_test(test_file):
    """Run a single test file through pytest in verbose mode.

    Args:
        test_file: The test target passed as the only positional argument
            to ``pytest.main`` (followed by ``-v``).

    Returns:
        The exit code reported by ``pytest.main``, so callers can tell
        whether the run succeeded instead of having the result discarded.
    """
    return pytest.main([test_file, "-v"])
|
|
|
|
|
|
def build_test_app(router, prefix: str = "") -> FastAPI:
    """Create a bare FastAPI application with ``router`` mounted on it.

    Args:
        router: The router to include in the new application.
        prefix: URL prefix forwarded to ``include_router``; empty by default.

    Returns:
        The freshly created application with the router attached.
    """
    application = FastAPI()
    application.include_router(router, prefix=prefix)
    return application
|
|
|
|
|
|
def load_module_from_path(module_name: str, relative_path: str):
    """Load a Python source file as a fresh module object.

    Args:
        module_name: Name used to build the import spec for the module.
        relative_path: Location of the source file, joined onto PROJECT_ROOT.

    Returns:
        The module after its code has been executed. This function does not
        register the module in ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
    """
    module_path = PROJECT_ROOT / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before the spec is used: spec_from_file_location can return
    # None, and an assert would be stripped when Python runs with -O.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {module_name!r} from {module_path}"
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
|
|
module = types.ModuleType(name)
|
|
if package:
|
|
module.__path__ = []
|
|
if attrs:
|
|
for key, value in attrs.items():
|
|
setattr(module, key, value)
|
|
|
|
monkeypatch.setitem(sys.modules, name, module)
|
|
|
|
parent_name, _, child_name = name.rpartition(".")
|
|
if parent_name:
|
|
parent = sys.modules.get(parent_name)
|
|
if parent is None:
|
|
try:
|
|
parent = importlib.import_module(parent_name)
|
|
except Exception:
|
|
parent = types.ModuleType(parent_name)
|
|
parent.__path__ = []
|
|
monkeypatch.setitem(sys.modules, parent_name, parent)
|
|
setattr(parent, child_name, module)
|
|
|
|
return module
|
|
|
|
|
|
class FakeCursor:
    """Async cursor double that records queries and replays queued results.

    Each ``fetchone``/``fetchall`` call consumes the next queued result, and
    each ``execute`` call consumes the next queued row count, if any remain.
    """

    def __init__(
        self,
        *,
        fetchone_results=None,
        fetchall_results=None,
        rowcount: int = 0,
        rowcounts=None,
    ):
        # Every (query text, params) pair passed to execute(), in call order.
        self.executed: list[tuple[str, dict | tuple | None]] = []
        self.rowcount = rowcount
        # Copy the inputs so consuming the queues never mutates caller data.
        self._rowcounts = list(rowcounts) if rowcounts else []
        self._fetchall_results = list(fetchall_results) if fetchall_results else []
        self._fetchone_results = list(fetchone_results) if fetchone_results else []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # False: exceptions raised inside the ``async with`` body propagate.
        return False

    async def execute(self, query, params=None):
        """Record the call and advance ``rowcount`` if a queued value remains."""
        self.executed.append((str(query), params))
        if not self._rowcounts:
            return
        self.rowcount = self._rowcounts.pop(0)

    async def fetchone(self):
        """Return the next queued single-row result, or None when exhausted."""
        pending = self._fetchone_results
        return pending.pop(0) if pending else None

    async def fetchall(self):
        """Return the next queued multi-row result, or an empty list."""
        pending = self._fetchall_results
        return pending.pop(0) if pending else []
|
|
|
|
|
|
class FakeConnection:
    """Async-context-manager connection double wrapping one fixed cursor."""

    def __init__(self, cursor: FakeCursor):
        # The single cursor object handed back by every cursor() call.
        self.cursor_instance = cursor

    def cursor(self):
        """Return the cursor this connection was constructed with."""
        return self.cursor_instance

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # False: exceptions raised inside the ``async with`` body propagate.
        return False
|
|
|
|
|
|
class FakeDB:
    """Database double exposing one pre-built FakeConnection."""

    def __init__(self, cursor: FakeCursor):
        # Built once here; get_connection() returns this same object each time.
        self.connection = FakeConnection(cursor)

    def get_connection(self):
        """Return the connection created at construction time."""
        return self.connection
|
|
|
|
|
|
class FakeExecuteResult:
    """Result double exposing ``scalars()``, ``all()`` and ``scalar()``."""

    def __init__(self, *, rows=None, scalar_value=None):
        self._scalar_value = scalar_value
        # Copy the rows so the result owns its own list.
        self._rows = list(rows) if rows else []

    def scalar(self):
        """Return the configured scalar value."""
        return self._scalar_value

    def scalars(self):
        """Return this same object so ``.scalars().all()`` can be chained."""
        return self

    def all(self):
        """Return the stored list of rows."""
        return self._rows
|
|
|
|
|
|
class FakeAsyncSession:
    """Async session double that records calls and replays queued results."""

    def __init__(self, execute_results=None):
        # Results handed out by execute(), consumed front to back.
        self._execute_results = list(execute_results) if execute_results else []
        # Call logs that tests can inspect afterwards.
        self.executed = []
        self.added = []
        self.refreshed = []
        self.commit_count = 0

    def add(self, obj):
        """Record an object passed to ``add``."""
        self.added.append(obj)

    async def execute(self, stmt):
        """Record ``stmt`` and return the next queued result.

        Falls back to an empty ``FakeExecuteResult`` once the queue is empty.
        """
        self.executed.append(stmt)
        if not self._execute_results:
            return FakeExecuteResult()
        return self._execute_results.pop(0)

    async def commit(self):
        """Count the commit call."""
        self.commit_count = self.commit_count + 1

    async def refresh(self, obj):
        """Record an object passed to ``refresh``."""
        self.refreshed.append(obj)
|
|
|
|
|
|
def make_user(**overrides):
    """Build a ``UserInDB`` populated with default test values.

    Args:
        **overrides: Field values that replace or extend the defaults.

    Returns:
        A ``UserInDB`` constructed from the merged field values.
    """
    # Imported at call time rather than at module import time.
    from app.domain.models.role import UserRole
    from app.domain.schemas.user import UserInDB

    fixed_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
    defaults = {
        "id": 1,
        "username": "tester",
        "email": "tester@example.com",
        "hashed_password": "hashed-password",
        "role": UserRole.USER,
        "is_active": True,
        "is_superuser": False,
        "created_at": fixed_time,
        "updated_at": fixed_time,
    }
    # Later entries win, so caller-supplied overrides beat the defaults.
    return UserInDB(**{**defaults, **overrides})
|
|
|
|
|
|
def make_audit_log(**overrides):
    """Build a ``SimpleNamespace`` shaped like an audit-log record.

    Args:
        **overrides: Attribute values that replace or extend the defaults.

    Returns:
        A ``SimpleNamespace`` carrying the merged attributes. ``id``,
        ``user_id`` and ``project_id`` default to fresh random UUIDs.
    """
    fixed_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
    defaults = {
        "id": uuid4(),
        "user_id": uuid4(),
        "project_id": uuid4(),
        "action": "LOGIN",
        "resource_type": "user",
        "resource_id": "1",
        "ip_address": "127.0.0.1",
        "request_method": "GET",
        "request_path": "/audit/logs",
        "request_data": {"ok": True},
        "response_status": 200,
        "timestamp": fixed_time,
    }
    # Later entries win, so caller-supplied overrides beat the defaults.
    return SimpleNamespace(**{**defaults, **overrides})
|