新增 API 测试用例,修复失效接口问题

This commit is contained in:
2026-05-21 15:32:12 +08:00
parent 751950e5b5
commit 2317f4d527
15 changed files with 1486 additions and 96 deletions
+91
View File
@@ -0,0 +1,91 @@
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import audit as audit_endpoint
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from tests.conftest import build_test_app, make_audit_log
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
app = build_test_app(audit_endpoint.router, "/audit")
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
if metadata_admin is not None:
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
if metadata_user is not None:
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
return TestClient(app)
def test_get_audit_logs_passes_filters():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get(
"/audit/logs",
params={
"action": "LOGIN",
"resource_type": "user",
"skip": 2,
"limit": 5,
},
)
assert response.status_code == 200
assert response.json()[0]["action"] == "LOGIN"
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["action"] == "LOGIN"
assert kwargs["resource_type"] == "user"
assert kwargs["skip"] == 2
assert kwargs["limit"] == 5
def test_get_audit_logs_count_returns_count_payload():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(),
"get_log_count": AsyncMock(return_value=7),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
assert response.status_code == 200
assert response.json() == {"count": 7}
repo.get_log_count.assert_awaited_once()
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
def test_get_my_audit_logs_forces_current_user_id():
current_user = type("User", (), {"id": make_audit_log().user_id})()
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_user=current_user)
response = client.get("/audit/logs/my", params={"limit": 3})
assert response.status_code == 200
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["user_id"] == current_user.id
assert kwargs["limit"] == 3
+139
View File
@@ -0,0 +1,139 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import auth as auth_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.core.security import create_access_token, create_refresh_token, get_password_hash
from tests.conftest import build_test_app, make_user
def _build_client(repo, current_user=None) -> TestClient:
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
return TestClient(app)
def test_register_success():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[False, False]),
create_user=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 201
assert response.json()["username"] == "tester"
def test_register_rejects_duplicate_username():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[True]),
create_user=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 400
assert response.json()["detail"] == "Username already registered"
repo.create_user.assert_not_awaited()
def test_login_supports_email_lookup():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=None),
get_user_by_email=AsyncMock(
return_value=make_user(
email="tester@example.com",
hashed_password=hashed_password,
)
),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login",
data={"username": "tester@example.com", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
def test_login_simple_uses_query_params():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(
return_value=make_user(hashed_password=hashed_password)
),
get_user_by_email=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login/simple",
params={"username": "tester", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
def test_me_returns_current_user_info():
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
response = client.get("/api/v1/auth/me")
assert response.status_code == 200
assert response.json()["username"] == "alice"
def test_refresh_rejects_access_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock())
client = _build_client(repo)
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": create_access_token("tester")},
)
assert response.status_code == 401
def test_refresh_success_returns_new_access_token():
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
refresh_token = create_refresh_token("tester")
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": refresh_token},
)
assert response.status_code == 200
payload = response.json()
assert payload["refresh_token"] == refresh_token
assert payload["token_type"] == "bearer"
+152
View File
@@ -0,0 +1,152 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import uuid4
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _load_project_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.project_info",
{},
)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"ChangeSet": DummyChangeSet,
"list_project": lambda: ["demo"],
"have_project": lambda network: network == "demo",
"create_project": lambda network: None,
"delete_project": lambda network: None,
"is_project_open": lambda network: False,
"open_project": lambda network: None,
"close_project": lambda network: None,
"copy_project": lambda source, target: None,
"import_inp": lambda network, cs: {"ok": True},
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
"read_inp": lambda network, inp: True,
"dump_inp": lambda network, inp: True,
"get_all_vertices": lambda network: [],
"get_all_scada_elements": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_extension_data": lambda network, key: None,
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
},
)
install_stub(
monkeypatch,
"app.auth.project_dependencies",
{"get_metadata_repository": lambda: None},
)
install_stub(
monkeypatch,
"app.infra.db.postgresql.database",
{"get_database_instance": lambda network: None},
)
install_stub(
monkeypatch,
"app.infra.db.timescaledb.database",
{"get_database_instance": lambda network: None},
)
return load_module_from_path(
"tests_project_endpoints_module",
"app/api/v1/endpoints/project.py",
)
def test_project_info_returns_404_when_missing(monkeypatch):
module = _load_project_module(monkeypatch)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "missing"})
assert response.status_code == 404
assert response.json()["detail"] == "Project missing not found"
def test_project_info_returns_geoserver_payload(monkeypatch):
module = _load_project_module(monkeypatch)
detail = SimpleNamespace(
project_id=uuid4(),
name="Demo Project",
code="demo",
description="desc",
gs_workspace="ws",
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
status="active",
geoserver=SimpleNamespace(
gs_base_url="http://gs",
gs_admin_user="admin",
gs_datastore_name="store",
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
srid=4326,
),
)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "demo"})
assert response.status_code == 200
payload = response.json()
assert payload["code"] == "demo"
assert payload["geoserver"]["gs_base_url"] == "http://gs"
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
module = _load_project_module(monkeypatch)
called = []
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
async def failing_get_pg_db(network):
raise RuntimeError("db down")
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post("/api/v1/openproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.json() == "demo"
assert called == ["demo"]
def test_project_lock_lifecycle(monkeypatch):
module = _load_project_module(monkeypatch)
module.lockedPrjs.clear()
client = TestClient(build_test_app(module.router, "/api/v1"))
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
assert first_lock.json() == 0
assert second_lock.json() == 1
assert locked_by_me.json() is True
assert unlock.json() is True
assert locked.json() is False
+154
View File
@@ -0,0 +1,154 @@
from typing import Any
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _noop(*args, **kwargs):
return None
def _load_regions_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"Any": Any,
"ChangeSet": DummyChangeSet,
"add_district_metering_area": _noop,
"add_region": _noop,
"add_service_area": _noop,
"add_virtual_district": _noop,
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
"calculate_service_area": lambda network: [],
"calculate_virtual_district": lambda *args, **kwargs: {},
"delete_district_metering_area": _noop,
"delete_region": _noop,
"delete_service_area": _noop,
"delete_virtual_district": _noop,
"generate_district_metering_area": _noop,
"generate_service_area": _noop,
"generate_sub_district_metering_area": _noop,
"generate_virtual_district": _noop,
"get_all_district_metering_area_ids": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_district_metering_area": lambda network, area_id: {},
"get_district_metering_area_schema": lambda network: {},
"get_region": lambda network, region_id: {},
"get_region_schema": lambda network: {},
"get_service_area": lambda network, area_id: {},
"get_service_area_schema": lambda network: {},
"get_virtual_district": lambda network, area_id: {},
"get_virtual_district_schema": lambda network: {},
"set_district_metering_area": _noop,
"set_region": _noop,
"set_service_area": _noop,
"set_virtual_district": _noop,
},
)
return load_module_from_path(
"tests_regions_endpoints_module",
"app/api/v1/endpoints/network/regions.py",
)
def test_removed_routes_are_absent_and_return_404(monkeypatch):
module = _load_regions_module(monkeypatch)
client = TestClient(build_test_app(module.router, "/api/v1"))
openapi = client.get("/openapi.json").json()
assert "/api/v1/calculateregion/" not in openapi["paths"]
assert "/api/v1/getallregions/" not in openapi["paths"]
assert "/api/v1/generateregion/" not in openapi["paths"]
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
module = _load_regions_module(monkeypatch)
calls = []
monkeypatch.setattr(
module,
"calculate_service_area",
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get(
"/api/v1/calculateservicearea/",
params={"network": "demo", "time_index": 5},
)
schema = client.get("/openapi.json").json()
assert response.status_code == 200
assert response.json() == [{"source-1": ["n1", "n2"]}]
assert calls == ["demo"]
parameter_names = [
item["name"]
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
]
assert parameter_names == ["network"]
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_add(network, change_set):
captured["network"] = network
captured["boundary"] = change_set.operations[0]["boundary"]
return {"ok": True}
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/adddistrictmeteringarea/",
params={"network": "demo"},
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
)
assert response.status_code == 200
assert captured == {
"network": "demo",
"boundary": [(1, 2), (3, 4), (1, 2)],
}
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_generate(network, centers, inflate_delta):
captured["args"] = (network, centers, inflate_delta)
return {"generated": True}
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/generatevirtualdistrict/",
params={"network": "demo", "inflate_delta": 0.75},
json={"centers": ["J1", "J2"]},
)
assert response.status_code == 200
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
+175
View File
@@ -0,0 +1,175 @@
from pathlib import Path
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
def _load_simulation_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.simulation",
{"run_simulation": lambda **kwargs: None},
)
install_stub(monkeypatch, "app.services.globals", {})
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"run_project": lambda network: "report",
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
"run_inp": lambda network: "inp-report",
"dump_output": lambda output: f"dump::{output}",
},
)
install_stub(monkeypatch, "app.algorithms", package=True)
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
install_stub(
monkeypatch,
"app.algorithms.simulation.scenarios",
{
"burst_analysis": lambda *args, **kwargs: "burst",
"valve_close_analysis": lambda *args, **kwargs: "valve",
"flushing_analysis": lambda *args, **kwargs: "flush",
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
"age_analysis": lambda *args, **kwargs: "age",
"pressure_regulation": lambda *args, **kwargs: "pressure",
},
)
install_stub(
monkeypatch,
"app.algorithms.sensor",
{
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
},
)
install_stub(
monkeypatch,
"app.services.network_import",
{"network_update": lambda *args, **kwargs: "updated"},
)
install_stub(
monkeypatch,
"app.services.simulation_ops",
{
"project_management": lambda *args, **kwargs: "managed",
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
},
)
install_stub(
monkeypatch,
"app.services.valve_isolation",
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
)
return load_module_from_path(
"tests_simulation_endpoints_module",
"app/api/v1/endpoints/simulation.py",
)
def test_run_project_endpoint_returns_plain_text(monkeypatch):
module = _load_simulation_module(monkeypatch)
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get("/api/v1/runproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.text == "report::demo"
def test_scheduling_analysis_maps_request_body(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
captured["args"] = (
network,
start_time,
pump_control,
tank_id,
water_plant_output_id,
time_delta,
)
return "scheduled"
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/scheduling_analysis/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1, 0, 1]},
"tank_id": "T1",
"water_plant_output_id": "R1",
},
)
assert response.status_code == 200
assert response.json() == "scheduled"
assert captured["args"] == (
"demo",
"2025-01-01T08:00:00+08:00",
{"P1": [1, 0, 1]},
"T1",
"R1",
300,
)
def test_project_management_maps_named_arguments(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_project_management(**kwargs):
captured.update(kwargs)
return "managed"
monkeypatch.setattr(module, "project_management", fake_project_management)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/project_management/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_init_level": {"T1": 10.0},
"region_demand": {"R1": 20.0},
},
)
assert response.status_code == 200
assert response.json() == "managed"
assert captured == {
"prj_name": "demo",
"start_datetime": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_initial_level_control": {"T1": 10.0},
"region_demand_control": {"R1": 20.0},
}
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
module = _load_simulation_module(monkeypatch)
monkeypatch.chdir(tmp_path)
def boom(_path):
raise RuntimeError("write failed")
monkeypatch.setattr(module, "network_update", boom)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/network_update/",
files={"file": ("update.txt", b"payload")},
)
assert response.status_code == 500
assert "数据库操作失败: write failed" in response.json()["detail"]
assert list(Path(tmp_path).glob("network_update_*"))
@@ -0,0 +1,95 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import user_management as user_management_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.auth.permissions import get_current_admin
from app.domain.models.role import UserRole
from tests.conftest import build_test_app, make_user
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
app = build_test_app(user_management_endpoint.router, "/users")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
if admin_user is not None:
app.dependency_overrides[get_current_admin] = lambda: admin_user
return TestClient(app)
def test_list_users_requires_admin_role():
repo = SimpleNamespace(
get_all_users=AsyncMock(
return_value=[
make_user(id=1, username="admin", role=UserRole.ADMIN),
make_user(id=2, username="user2"),
]
)
)
client = _build_client(
repo,
current_user=make_user(id=1, role=UserRole.ADMIN),
)
response = client.get("/users/", params={"skip": 5, "limit": 2})
assert response.status_code == 200
assert len(response.json()) == 2
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
def test_get_user_rejects_non_owner_non_admin():
repo = SimpleNamespace(get_user_by_id=AsyncMock())
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
response = client.get("/users/3")
assert response.status_code == 403
assert response.json()["detail"] == "You don't have permission to view this user"
repo.get_user_by_id.assert_not_awaited()
def test_update_user_blocks_role_change_for_non_admin():
repo = SimpleNamespace(
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
update_user=AsyncMock(),
)
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
response = client.put("/users/1", json={"role": "ADMIN"})
assert response.status_code == 403
assert response.json()["detail"] == "Only admins can change user roles"
repo.update_user.assert_not_awaited()
def test_delete_user_blocks_self_delete_for_admin():
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
repo = SimpleNamespace(delete_user=AsyncMock())
client = _build_client(repo, admin_user=admin_user)
response = client.delete("/users/1")
assert response.status_code == 400
assert response.json()["detail"] == "You cannot delete your own account"
repo.delete_user.assert_not_awaited()
def test_activate_user_updates_active_flag():
repo = SimpleNamespace(
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
)
client = _build_client(
repo,
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
)
response = client.post("/users/2/activate")
assert response.status_code == 200
assert response.json()["is_active"] is True
user_update = repo.update_user.await_args.args[1]
assert user_update.is_active is True
+36
View File
@@ -0,0 +1,36 @@
from jose import jwt
from app.core.config import settings
from app.core.security import (
create_access_token,
create_refresh_token,
get_password_hash,
verify_password,
)
def test_password_hash_roundtrip():
hashed = get_password_hash("secret123")
assert hashed != "secret123"
assert verify_password("secret123", hashed) is True
assert verify_password("wrong", hashed) is False
def test_create_access_token_sets_access_type():
token = create_access_token("alice")
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == "alice"
assert payload["type"] == "access"
assert "exp" in payload
assert "iat" in payload
def test_create_refresh_token_sets_refresh_type():
token = create_refresh_token("alice")
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == "alice"
assert payload["type"] == "refresh"
assert "exp" in payload
assert "iat" in payload
+189 -6
View File
@@ -1,14 +1,197 @@
import pytest
import sys
import importlib
import importlib.util
import os
import sys
import types
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
from fastapi import FastAPI
# 自动添加项目根目录到路径(处理项目结构)
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))
def run_this_test(test_file):
"""自定义函数:运行单个测试文件(类似pytest)"""
# 提取测试文件名(无扩展名)
test_name = os.path.splitext(os.path.basename(test_file))[0]
# 使用pytest运行(自动处理导入)
pytest.main([test_file, "-v"])
def build_test_app(router, prefix: str = "") -> FastAPI:
app = FastAPI()
app.include_router(router, prefix=prefix)
return app
def load_module_from_path(module_name: str, relative_path: str):
module_path = PROJECT_ROOT / relative_path
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module)
return module
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
module = types.ModuleType(name)
if package:
module.__path__ = []
if attrs:
for key, value in attrs.items():
setattr(module, key, value)
monkeypatch.setitem(sys.modules, name, module)
parent_name, _, child_name = name.rpartition(".")
if parent_name:
parent = sys.modules.get(parent_name)
if parent is None:
try:
parent = importlib.import_module(parent_name)
except Exception:
parent = types.ModuleType(parent_name)
parent.__path__ = []
monkeypatch.setitem(sys.modules, parent_name, parent)
setattr(parent, child_name, module)
return module
class FakeCursor:
def __init__(
self,
*,
fetchone_results=None,
fetchall_results=None,
rowcount: int = 0,
rowcounts=None,
):
self._fetchone_results = list(fetchone_results or [])
self._fetchall_results = list(fetchall_results or [])
self._rowcounts = list(rowcounts or [])
self.rowcount = rowcount
self.executed: list[tuple[str, dict | tuple | None]] = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def execute(self, query, params=None):
self.executed.append((str(query), params))
if self._rowcounts:
self.rowcount = self._rowcounts.pop(0)
async def fetchone(self):
if self._fetchone_results:
return self._fetchone_results.pop(0)
return None
async def fetchall(self):
if self._fetchall_results:
return self._fetchall_results.pop(0)
return []
class FakeConnection:
def __init__(self, cursor: FakeCursor):
self.cursor_instance = cursor
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
def cursor(self):
return self.cursor_instance
class FakeDB:
def __init__(self, cursor: FakeCursor):
self.connection = FakeConnection(cursor)
def get_connection(self):
return self.connection
class FakeExecuteResult:
def __init__(self, *, rows=None, scalar_value=None):
self._rows = list(rows or [])
self._scalar_value = scalar_value
def scalars(self):
return self
def all(self):
return self._rows
def scalar(self):
return self._scalar_value
class FakeAsyncSession:
def __init__(self, execute_results=None):
self._execute_results = list(execute_results or [])
self.executed = []
self.added = []
self.commit_count = 0
self.refreshed = []
def add(self, obj):
self.added.append(obj)
async def execute(self, stmt):
self.executed.append(stmt)
if self._execute_results:
return self._execute_results.pop(0)
return FakeExecuteResult()
async def commit(self):
self.commit_count += 1
async def refresh(self, obj):
self.refreshed.append(obj)
def make_user(**overrides):
from app.domain.models.role import UserRole
from app.domain.schemas.user import UserInDB
data = {
"id": 1,
"username": "tester",
"email": "tester@example.com",
"hashed_password": "hashed-password",
"role": UserRole.USER,
"is_active": True,
"is_superuser": False,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
data.update(overrides)
return UserInDB(**data)
def make_audit_log(**overrides):
data = {
"id": uuid4(),
"user_id": uuid4(),
"project_id": uuid4(),
"action": "LOGIN",
"resource_type": "user",
"resource_id": "1",
"ip_address": "127.0.0.1",
"request_method": "GET",
"request_path": "/audit/logs",
"request_data": {"ok": True},
"response_status": 200,
"timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
data.update(overrides)
return SimpleNamespace(**data)
+79
View File
@@ -0,0 +1,79 @@
import asyncio
from datetime import datetime, timezone
from uuid import uuid4
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
from tests.conftest import FakeAsyncSession, FakeExecuteResult, make_audit_log
def test_create_log_adds_commits_and_refreshes(monkeypatch):
class FakeAuditLog:
def __init__(self, **kwargs):
self.id = uuid4()
for key, value in kwargs.items():
setattr(self, key, value)
session = FakeAsyncSession()
repo = AuditRepository(session)
monkeypatch.setattr(
"app.infra.db.metadb.repositories.audit_repository.models.AuditLog",
FakeAuditLog,
)
result = asyncio.run(
repo.create_log(
action="LOGIN",
request_method="POST",
request_path="/auth/login",
response_status=200,
)
)
assert result.action == "LOGIN"
assert result.request_method == "POST"
assert session.commit_count == 1
assert len(session.added) == 1
assert len(session.refreshed) == 1
def test_get_logs_builds_filtered_query_and_returns_models():
log = make_audit_log(action="UPDATE_USER", resource_type="user")
session = FakeAsyncSession(
execute_results=[FakeExecuteResult(rows=[log])],
)
repo = AuditRepository(session)
user_id = uuid4()
project_id = uuid4()
start_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
results = asyncio.run(
repo.get_logs(
user_id=user_id,
project_id=project_id,
action="UPDATE_USER",
resource_type="user",
start_time=start_time,
skip=5,
limit=10,
)
)
assert len(results) == 1
assert results[0].action == "UPDATE_USER"
stmt = session.executed[0]
assert len(stmt._where_criteria) == 5
assert stmt._offset == 5
assert stmt._limit == 10
def test_get_log_count_returns_zero_when_scalar_none():
session = FakeAsyncSession(
execute_results=[FakeExecuteResult(scalar_value=None)],
)
repo = AuditRepository(session)
result = asyncio.run(repo.get_log_count(action="DELETE_USER"))
assert result == 0
stmt = session.executed[0]
assert len(stmt._where_criteria) == 1
+97
View File
@@ -0,0 +1,97 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from app.auth import dependencies
from app.core.security import create_access_token, create_refresh_token
from tests.conftest import make_user
def test_get_db_returns_app_state_db():
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(db="db-instance")))
result = asyncio.run(dependencies.get_db(request))
assert result == "db-instance"
def test_get_db_raises_when_database_missing():
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace()))
with pytest.raises(HTTPException) as exc_info:
asyncio.run(dependencies.get_db(request))
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Database not initialized"
def test_get_current_user_accepts_valid_access_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=make_user()))
result = asyncio.run(
dependencies.get_current_user(
token=create_access_token("tester"),
user_repo=repo,
)
)
assert result.username == "tester"
repo.get_user_by_username.assert_awaited_once_with("tester")
def test_get_current_user_rejects_refresh_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock())
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_user(
token=create_refresh_token("tester"),
user_repo=repo,
)
)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid token type. Access token required."
repo.get_user_by_username.assert_not_awaited()
def test_get_current_user_rejects_missing_user():
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=None))
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_user(
token=create_access_token("ghost"),
user_repo=repo,
)
)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
def test_get_current_active_user_rejects_inactive_user():
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_active_user(
current_user=make_user(is_active=False),
)
)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Inactive user"
def test_get_current_superuser_rejects_non_superuser():
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_superuser(
current_user=make_user(is_superuser=False),
)
)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Not enough privileges. Superuser access required."
+56
View File
@@ -0,0 +1,56 @@
import asyncio
import pytest
from fastapi import HTTPException
from app.auth import permissions
from app.domain.models.role import UserRole
from tests.conftest import make_user
def test_require_role_allows_higher_privilege_user():
checker = permissions.require_role(UserRole.OPERATOR)
result = asyncio.run(checker(current_user=make_user(role=UserRole.ADMIN)))
assert result.role == UserRole.ADMIN
def test_require_role_rejects_insufficient_role():
checker = permissions.require_role(UserRole.ADMIN)
with pytest.raises(HTTPException) as exc_info:
asyncio.run(checker(current_user=make_user(role=UserRole.USER)))
assert exc_info.value.status_code == 403
assert "Required role: ADMIN" in exc_info.value.detail
def test_check_resource_owner_allows_admin():
assert permissions.check_resource_owner(
99,
make_user(id=1, role=UserRole.ADMIN),
) is True
def test_check_resource_owner_allows_owner():
assert permissions.check_resource_owner(
7,
make_user(id=7, role=UserRole.USER),
) is True
def test_check_resource_owner_rejects_other_user():
assert permissions.check_resource_owner(
7,
make_user(id=8, role=UserRole.USER),
) is False
def test_require_owner_or_admin_rejects_other_user():
checker = permissions.require_owner_or_admin(7)
with pytest.raises(HTTPException) as exc_info:
asyncio.run(checker(current_user=make_user(id=8, role=UserRole.USER)))
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "You don't have permission to access this resource"
+19 -18
View File
@@ -1,9 +1,8 @@
import asyncio
from datetime import datetime, timezone
import importlib.util
from pathlib import Path
import pytest
def _load_scada_repository():
module_path = (
@@ -50,18 +49,19 @@ class _FakeConnection:
return self.cursor_instance
@pytest.mark.asyncio
async def test_update_scada_field_inserts_when_update_hits_no_rows():
def test_update_scada_field_inserts_when_update_hits_no_rows():
ScadaRepository = _load_scada_repository()
conn = _FakeConnection(initial_rowcount=0)
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
await ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
asyncio.run(
ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
)
)
assert len(conn.cursor_instance.calls) == 2
@@ -69,18 +69,19 @@ async def test_update_scada_field_inserts_when_update_hits_no_rows():
assert "INSERT INTO scada.scada_data" in conn.cursor_instance.calls[1][0]
@pytest.mark.asyncio
async def test_update_scada_field_skips_insert_when_update_succeeds():
def test_update_scada_field_skips_insert_when_update_succeeds():
ScadaRepository = _load_scada_repository()
conn = _FakeConnection(initial_rowcount=1)
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
await ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
asyncio.run(
ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
)
)
assert len(conn.cursor_instance.calls) == 1
+124
View File
@@ -0,0 +1,124 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from app.domain.models.role import UserRole
from app.domain.schemas.user import UserCreate, UserUpdate
from app.infra.db.metadb.repositories.user_repository import UserRepository
from tests.conftest import FakeCursor, FakeDB
def _user_row(**overrides):
base = {
"id": 1,
"username": "tester",
"email": "tester@example.com",
"hashed_password": "hashed-password",
"role": "USER",
"is_active": True,
"is_superuser": False,
"created_at": "2025-01-01T00:00:00+00:00",
"updated_at": "2025-01-01T00:00:00+00:00",
}
base.update(overrides)
return base
def test_create_user_hashes_password_and_returns_model(monkeypatch):
cursor = FakeCursor(fetchone_results=[_user_row()])
repo = UserRepository(FakeDB(cursor))
monkeypatch.setattr(
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
lambda password: f"hashed::{password}",
)
result = asyncio.run(
repo.create_user(
UserCreate(
username="tester",
email="tester@example.com",
password="secret123",
)
)
)
assert result is not None
assert result.username == "tester"
assert cursor.executed[0][1]["hashed_password"] == "hashed::secret123"
def test_update_user_without_fields_returns_existing_user(monkeypatch):
repo = UserRepository(FakeDB(FakeCursor()))
existing_user = AsyncMock(return_value="existing")
monkeypatch.setattr(repo, "get_user_by_id", existing_user)
result = asyncio.run(repo.update_user(1, UserUpdate()))
assert result == "existing"
existing_user.assert_awaited_once_with(1)
def test_update_user_builds_dynamic_query(monkeypatch):
cursor = FakeCursor(fetchone_results=[_user_row(role="ADMIN", email="new@example.com")])
repo = UserRepository(FakeDB(cursor))
monkeypatch.setattr(
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
lambda password: f"hashed::{password}",
)
result = asyncio.run(
repo.update_user(
1,
UserUpdate(
email="new@example.com",
password="new-secret",
role=UserRole.ADMIN,
is_active=False,
),
),
)
assert result is not None
query, params = cursor.executed[0]
assert "email = %(email)s" in query
assert "hashed_password = %(hashed_password)s" in query
assert "role = %(role)s" in query
assert "is_active = %(is_active)s" in query
assert params["hashed_password"] == "hashed::new-secret"
assert params["role"] == "ADMIN"
assert params["is_active"] is False
def test_delete_user_returns_false_when_execute_raises():
cursor = FakeCursor()
cursor.execute = AsyncMock(side_effect=RuntimeError("boom"))
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(repo.delete_user(1))
assert result is False
def test_user_exists_short_circuits_without_filters():
cursor = FakeCursor()
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(repo.user_exists())
assert result is False
assert cursor.executed == []
def test_user_exists_checks_username_or_email():
cursor = FakeCursor(fetchone_results=[{"exists": True}])
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(
repo.user_exists(username="tester", email="tester@example.com")
)
assert result is True
query, params = cursor.executed[0]
assert "username = %(username)s OR email = %(email)s" in query
assert params == {"username": "tester", "email": "tester@example.com"}