Files
TJWaterServerBinary/tests/conftest.py
T

198 lines
5.2 KiB
Python

import importlib
import importlib.util
import os
import sys
import types
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
from fastapi import FastAPI
# Resolve the repository root (the parent of this tests/ directory) and put it
# at the front of sys.path, so project packages such as `app` can be imported
# regardless of the working directory pytest was launched from (the path is
# absolute because __file__ is resolved).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, os.fspath(PROJECT_ROOT))
def run_this_test(test_file):
    """Run a single test file under pytest with verbose output.

    Convenience helper for running one test module directly (for example from
    an ``if __name__ == "__main__":`` guard) instead of through the pytest CLI.

    Args:
        test_file: Path of the test file to run (``str`` or path-like).

    Returns:
        The exit code reported by ``pytest.main`` (0 when every test passed),
        so callers can propagate it, e.g. ``sys.exit(run_this_test(__file__))``.
    """
    # pytest.main expects a list of string arguments; str() also accepts Path objects.
    return pytest.main([str(test_file), "-v"])
def build_test_app(router, prefix: str = "") -> FastAPI:
    """Create a fresh FastAPI application with ``router`` mounted on it.

    Args:
        router: The router under test, passed straight to ``include_router``.
        prefix: Optional URL prefix for the router's routes (default: none).

    Returns:
        A new ``FastAPI`` instance with the router included.
    """
    application = FastAPI()
    application.include_router(router, prefix=prefix)
    return application
def load_module_from_path(module_name: str, relative_path: str):
    """Load and execute a Python source file given its path under the project root.

    The module is executed under ``module_name`` but is not registered in
    ``sys.modules``, so every call yields a fresh, independent module object.

    Args:
        module_name: Name given to the loaded module (becomes its ``__name__``).
        relative_path: Path of the source file, relative to ``PROJECT_ROOT``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for the path
            (e.g. the file does not have a recognised Python source suffix).
        Any exception raised while reading or executing the file propagates.
    """
    module_path = PROJECT_ROOT / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    # Validate before building the module: module_from_spec(None) would fail
    # with an unhelpful AttributeError, and a bare assert is stripped under -O.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {module_name!r} from {module_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
module = types.ModuleType(name)
if package:
module.__path__ = []
if attrs:
for key, value in attrs.items():
setattr(module, key, value)
monkeypatch.setitem(sys.modules, name, module)
parent_name, _, child_name = name.rpartition(".")
if parent_name:
parent = sys.modules.get(parent_name)
if parent is None:
try:
parent = importlib.import_module(parent_name)
except Exception:
parent = types.ModuleType(parent_name)
parent.__path__ = []
monkeypatch.setitem(sys.modules, parent_name, parent)
setattr(parent, child_name, module)
return module
class FakeCursor:
    """Async DB-cursor double that replays canned results and records queries.

    Each ``fetchone``/``fetchall`` call consumes the next queued result in FIFO
    order; once its queue is exhausted, ``fetchone`` returns ``None`` and
    ``fetchall`` returns ``[]``. If ``rowcounts`` is supplied, each ``execute``
    moves the next value into ``rowcount``; otherwise ``rowcount`` keeps its
    initial value.
    """

    def __init__(
        self,
        *,
        fetchone_results=None,
        fetchall_results=None,
        rowcount: int = 0,
        rowcounts=None,
    ):
        # Copy into fresh lists so the caller's sequences are never mutated.
        self._fetchone_results = list(fetchone_results) if fetchone_results else []
        self._fetchall_results = list(fetchall_results) if fetchall_results else []
        self._rowcounts = list(rowcounts) if rowcounts else []
        self.rowcount = rowcount
        # One (query text, params) pair per execute() call, in call order.
        self.executed: list[tuple[str, dict | tuple | None]] = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # False: never suppress exceptions raised inside the ``async with`` body.
        return False

    async def execute(self, query, params=None):
        self.executed.append((str(query), params))
        if not self._rowcounts:
            return
        self.rowcount = self._rowcounts.pop(0)

    async def fetchone(self):
        if not self._fetchone_results:
            return None
        return self._fetchone_results.pop(0)

    async def fetchall(self):
        if not self._fetchall_results:
            return []
        return self._fetchall_results.pop(0)
class FakeConnection:
    """Async connection double that always hands out one pre-built cursor."""

    def __init__(self, cursor: FakeCursor):
        # Every cursor() call returns this same object, so a test can inspect
        # what was executed on it after the code under test has run.
        self.cursor_instance = cursor

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False  # let exceptions from the ``async with`` body propagate

    def cursor(self):
        return self.cursor_instance
class FakeDB:
    """Database double exposing a single shared ``FakeConnection``."""

    def __init__(self, cursor: FakeCursor):
        conn = FakeConnection(cursor)
        self.connection = conn

    def get_connection(self):
        # The same connection (and therefore the same cursor) on every call.
        return self.connection
class FakeExecuteResult:
    """Canned result object, as returned by ``FakeAsyncSession.execute``.

    Supports the ``result.scalars().all()`` and ``result.scalar()`` call chains.
    """

    def __init__(self, *, rows=None, scalar_value=None):
        self._rows = [] if not rows else list(rows)
        self._scalar_value = scalar_value

    def scalars(self):
        # No separate scalar-result type: chaining simply returns this object.
        return self

    def all(self):
        # NOTE: hands back the internal list itself, not a copy.
        return self._rows

    def scalar(self):
        return self._scalar_value
class FakeAsyncSession:
    """Minimal async-session double that records what the code under test does.

    ``execute`` returns the queued ``execute_results`` in order and falls back
    to an empty ``FakeExecuteResult`` once the queue is exhausted.
    """

    def __init__(self, execute_results=None):
        self._execute_results = [] if not execute_results else list(execute_results)
        self.executed = []  # statements passed to execute(), in call order
        self.added = []  # objects passed to add()
        self.commit_count = 0  # number of awaited commit() calls
        self.refreshed = []  # objects passed to refresh()

    def add(self, obj):
        self.added.append(obj)

    async def execute(self, stmt):
        self.executed.append(stmt)
        if not self._execute_results:
            return FakeExecuteResult()
        return self._execute_results.pop(0)

    async def commit(self):
        self.commit_count += 1

    async def refresh(self, obj):
        self.refreshed.append(obj)
def make_user(**overrides):
    """Build a ``UserInDB`` populated with default test values.

    Keyword arguments replace (or add to) the default fields before the model
    is constructed.
    """
    # Imported inside the function, so `app` is only imported when a test
    # actually calls this helper.
    from app.domain.models.role import UserRole
    from app.domain.schemas.user import UserInDB

    jan_first = datetime(2025, 1, 1, tzinfo=timezone.utc)
    defaults = dict(
        id=1,
        username="tester",
        email="tester@example.com",
        hashed_password="hashed-password",
        role=UserRole.USER,
        is_active=True,
        is_superuser=False,
        created_at=jan_first,
        updated_at=jan_first,
    )
    return UserInDB(**{**defaults, **overrides})
def make_audit_log(**overrides):
    """Build a fake audit-log record as a ``SimpleNamespace``.

    ``id``, ``user_id`` and ``project_id`` receive fresh random UUIDs on every
    call. Any field can be replaced, and new ones added, through keyword
    arguments.
    """
    fields = dict(
        id=uuid4(),
        user_id=uuid4(),
        project_id=uuid4(),
        action="LOGIN",
        resource_type="user",
        resource_id="1",
        ip_address="127.0.0.1",
        request_method="GET",
        request_path="/audit/logs",
        request_data={"ok": True},
        response_status=200,
        timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc),
    )
    return SimpleNamespace(**{**fields, **overrides})