新增 API 测试用例,修复失效接口问题
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import audit as audit_endpoint
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from tests.conftest import build_test_app, make_audit_log
|
||||
|
||||
|
||||
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
|
||||
app = build_test_app(audit_endpoint.router, "/audit")
|
||||
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
|
||||
if metadata_admin is not None:
|
||||
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
|
||||
if metadata_user is not None:
|
||||
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_get_audit_logs_passes_filters():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get(
|
||||
"/audit/logs",
|
||||
params={
|
||||
"action": "LOGIN",
|
||||
"resource_type": "user",
|
||||
"skip": 2,
|
||||
"limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()[0]["action"] == "LOGIN"
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["action"] == "LOGIN"
|
||||
assert kwargs["resource_type"] == "user"
|
||||
assert kwargs["skip"] == 2
|
||||
assert kwargs["limit"] == 5
|
||||
|
||||
|
||||
def test_get_audit_logs_count_returns_count_payload():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(),
|
||||
"get_log_count": AsyncMock(return_value=7),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": 7}
|
||||
repo.get_log_count.assert_awaited_once()
|
||||
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
|
||||
|
||||
|
||||
def test_get_my_audit_logs_forces_current_user_id():
|
||||
current_user = type("User", (), {"id": make_audit_log().user_id})()
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_user=current_user)
|
||||
|
||||
response = client.get("/audit/logs/my", params={"limit": 3})
|
||||
|
||||
assert response.status_code == 200
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["user_id"] == current_user.id
|
||||
assert kwargs["limit"] == 3
|
||||
@@ -0,0 +1,139 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import auth as auth_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.core.security import create_access_token, create_refresh_token, get_password_hash
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, current_user=None) -> TestClient:
|
||||
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_register_success():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[False, False]),
|
||||
create_user=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["username"] == "tester"
|
||||
|
||||
|
||||
def test_register_rejects_duplicate_username():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[True]),
|
||||
create_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Username already registered"
|
||||
repo.create_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_login_supports_email_lookup():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=None),
|
||||
get_user_by_email=AsyncMock(
|
||||
return_value=make_user(
|
||||
email="tester@example.com",
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "tester@example.com", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
|
||||
|
||||
|
||||
def test_login_simple_uses_query_params():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(
|
||||
return_value=make_user(hashed_password=hashed_password)
|
||||
),
|
||||
get_user_by_email=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/simple",
|
||||
params={"username": "tester", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_me_returns_current_user_info():
|
||||
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
|
||||
|
||||
response = client.get("/api/v1/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "alice"
|
||||
|
||||
|
||||
def test_refresh_rejects_access_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock())
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": create_access_token("tester")},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_refresh_success_returns_new_access_token():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
refresh_token = create_refresh_token("tester")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["refresh_token"] == refresh_token
|
||||
assert payload["token_type"] == "bearer"
|
||||
@@ -0,0 +1,152 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _load_project_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.project_info",
|
||||
{},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"list_project": lambda: ["demo"],
|
||||
"have_project": lambda network: network == "demo",
|
||||
"create_project": lambda network: None,
|
||||
"delete_project": lambda network: None,
|
||||
"is_project_open": lambda network: False,
|
||||
"open_project": lambda network: None,
|
||||
"close_project": lambda network: None,
|
||||
"copy_project": lambda source, target: None,
|
||||
"import_inp": lambda network, cs: {"ok": True},
|
||||
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
|
||||
"read_inp": lambda network, inp: True,
|
||||
"dump_inp": lambda network, inp: True,
|
||||
"get_all_vertices": lambda network: [],
|
||||
"get_all_scada_elements": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_extension_data": lambda network, key: None,
|
||||
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.auth.project_dependencies",
|
||||
{"get_metadata_repository": lambda: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.postgresql.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.timescaledb.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_project_endpoints_module",
|
||||
"app/api/v1/endpoints/project.py",
|
||||
)
|
||||
|
||||
|
||||
def test_project_info_returns_404_when_missing(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "missing"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Project missing not found"
|
||||
|
||||
|
||||
def test_project_info_returns_geoserver_payload(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
detail = SimpleNamespace(
|
||||
project_id=uuid4(),
|
||||
name="Demo Project",
|
||||
code="demo",
|
||||
description="desc",
|
||||
gs_workspace="ws",
|
||||
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
status="active",
|
||||
geoserver=SimpleNamespace(
|
||||
gs_base_url="http://gs",
|
||||
gs_admin_user="admin",
|
||||
gs_datastore_name="store",
|
||||
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
srid=4326,
|
||||
),
|
||||
)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["code"] == "demo"
|
||||
assert payload["geoserver"]["gs_base_url"] == "http://gs"
|
||||
|
||||
|
||||
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
called = []
|
||||
|
||||
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
|
||||
|
||||
async def failing_get_pg_db(network):
|
||||
raise RuntimeError("db down")
|
||||
|
||||
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post("/api/v1/openproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "demo"
|
||||
assert called == ["demo"]
|
||||
|
||||
|
||||
def test_project_lock_lifecycle(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
module.lockedPrjs.clear()
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
|
||||
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
|
||||
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
|
||||
|
||||
assert first_lock.json() == 0
|
||||
assert second_lock.json() == 1
|
||||
assert locked_by_me.json() is True
|
||||
assert unlock.json() is True
|
||||
assert locked.json() is False
|
||||
@@ -0,0 +1,154 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def _load_regions_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"Any": Any,
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"add_district_metering_area": _noop,
|
||||
"add_region": _noop,
|
||||
"add_service_area": _noop,
|
||||
"add_virtual_district": _noop,
|
||||
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
|
||||
"calculate_service_area": lambda network: [],
|
||||
"calculate_virtual_district": lambda *args, **kwargs: {},
|
||||
"delete_district_metering_area": _noop,
|
||||
"delete_region": _noop,
|
||||
"delete_service_area": _noop,
|
||||
"delete_virtual_district": _noop,
|
||||
"generate_district_metering_area": _noop,
|
||||
"generate_service_area": _noop,
|
||||
"generate_sub_district_metering_area": _noop,
|
||||
"generate_virtual_district": _noop,
|
||||
"get_all_district_metering_area_ids": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_district_metering_area": lambda network, area_id: {},
|
||||
"get_district_metering_area_schema": lambda network: {},
|
||||
"get_region": lambda network, region_id: {},
|
||||
"get_region_schema": lambda network: {},
|
||||
"get_service_area": lambda network, area_id: {},
|
||||
"get_service_area_schema": lambda network: {},
|
||||
"get_virtual_district": lambda network, area_id: {},
|
||||
"get_virtual_district_schema": lambda network: {},
|
||||
"set_district_metering_area": _noop,
|
||||
"set_region": _noop,
|
||||
"set_service_area": _noop,
|
||||
"set_virtual_district": _noop,
|
||||
},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_regions_endpoints_module",
|
||||
"app/api/v1/endpoints/network/regions.py",
|
||||
)
|
||||
|
||||
|
||||
def test_removed_routes_are_absent_and_return_404(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
openapi = client.get("/openapi.json").json()
|
||||
|
||||
assert "/api/v1/calculateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/getallregions/" not in openapi["paths"]
|
||||
assert "/api/v1/generateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
|
||||
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
|
||||
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
|
||||
|
||||
|
||||
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"calculate_service_area",
|
||||
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
|
||||
)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/calculateservicearea/",
|
||||
params={"network": "demo", "time_index": 5},
|
||||
)
|
||||
schema = client.get("/openapi.json").json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [{"source-1": ["n1", "n2"]}]
|
||||
assert calls == ["demo"]
|
||||
parameter_names = [
|
||||
item["name"]
|
||||
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
|
||||
]
|
||||
assert parameter_names == ["network"]
|
||||
|
||||
|
||||
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_add(network, change_set):
|
||||
captured["network"] = network
|
||||
captured["boundary"] = change_set.operations[0]["boundary"]
|
||||
return {"ok": True}
|
||||
|
||||
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/adddistrictmeteringarea/",
|
||||
params={"network": "demo"},
|
||||
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured == {
|
||||
"network": "demo",
|
||||
"boundary": [(1, 2), (3, 4), (1, 2)],
|
||||
}
|
||||
|
||||
|
||||
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_generate(network, centers, inflate_delta):
|
||||
captured["args"] = (network, centers, inflate_delta)
|
||||
return {"generated": True}
|
||||
|
||||
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/generatevirtualdistrict/",
|
||||
params={"network": "demo", "inflate_delta": 0.75},
|
||||
json={"centers": ["J1", "J2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
|
||||
@@ -0,0 +1,175 @@
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
def _load_simulation_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation",
|
||||
{"run_simulation": lambda **kwargs: None},
|
||||
)
|
||||
install_stub(monkeypatch, "app.services.globals", {})
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"run_project": lambda network: "report",
|
||||
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
|
||||
"run_inp": lambda network: "inp-report",
|
||||
"dump_output": lambda output: f"dump::{output}",
|
||||
},
|
||||
)
|
||||
install_stub(monkeypatch, "app.algorithms", package=True)
|
||||
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.simulation.scenarios",
|
||||
{
|
||||
"burst_analysis": lambda *args, **kwargs: "burst",
|
||||
"valve_close_analysis": lambda *args, **kwargs: "valve",
|
||||
"flushing_analysis": lambda *args, **kwargs: "flush",
|
||||
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
|
||||
"age_analysis": lambda *args, **kwargs: "age",
|
||||
"pressure_regulation": lambda *args, **kwargs: "pressure",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.sensor",
|
||||
{
|
||||
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
|
||||
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.network_import",
|
||||
{"network_update": lambda *args, **kwargs: "updated"},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation_ops",
|
||||
{
|
||||
"project_management": lambda *args, **kwargs: "managed",
|
||||
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
|
||||
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.valve_isolation",
|
||||
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_simulation_endpoints_module",
|
||||
"app/api/v1/endpoints/simulation.py",
|
||||
)
|
||||
|
||||
|
||||
def test_run_project_endpoint_returns_plain_text(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get("/api/v1/runproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "report::demo"
|
||||
|
||||
|
||||
def test_scheduling_analysis_maps_request_body(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
|
||||
captured["args"] = (
|
||||
network,
|
||||
start_time,
|
||||
pump_control,
|
||||
tank_id,
|
||||
water_plant_output_id,
|
||||
time_delta,
|
||||
)
|
||||
return "scheduled"
|
||||
|
||||
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/scheduling_analysis/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1, 0, 1]},
|
||||
"tank_id": "T1",
|
||||
"water_plant_output_id": "R1",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "scheduled"
|
||||
assert captured["args"] == (
|
||||
"demo",
|
||||
"2025-01-01T08:00:00+08:00",
|
||||
{"P1": [1, 0, 1]},
|
||||
"T1",
|
||||
"R1",
|
||||
300,
|
||||
)
|
||||
|
||||
|
||||
def test_project_management_maps_named_arguments(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_project_management(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return "managed"
|
||||
|
||||
monkeypatch.setattr(module, "project_management", fake_project_management)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/project_management/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_init_level": {"T1": 10.0},
|
||||
"region_demand": {"R1": 20.0},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "managed"
|
||||
assert captured == {
|
||||
"prj_name": "demo",
|
||||
"start_datetime": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_initial_level_control": {"T1": 10.0},
|
||||
"region_demand_control": {"R1": 20.0},
|
||||
}
|
||||
|
||||
|
||||
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
def boom(_path):
|
||||
raise RuntimeError("write failed")
|
||||
|
||||
monkeypatch.setattr(module, "network_update", boom)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/network_update/",
|
||||
files={"file": ("update.txt", b"payload")},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "数据库操作失败: write failed" in response.json()["detail"]
|
||||
assert list(Path(tmp_path).glob("network_update_*"))
|
||||
@@ -0,0 +1,95 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import user_management as user_management_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.auth.permissions import get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
|
||||
app = build_test_app(user_management_endpoint.router, "/users")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
if admin_user is not None:
|
||||
app.dependency_overrides[get_current_admin] = lambda: admin_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_list_users_requires_admin_role():
|
||||
repo = SimpleNamespace(
|
||||
get_all_users=AsyncMock(
|
||||
return_value=[
|
||||
make_user(id=1, username="admin", role=UserRole.ADMIN),
|
||||
make_user(id=2, username="user2"),
|
||||
]
|
||||
)
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
current_user=make_user(id=1, role=UserRole.ADMIN),
|
||||
)
|
||||
|
||||
response = client.get("/users/", params={"skip": 5, "limit": 2})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
|
||||
|
||||
|
||||
def test_get_user_rejects_non_owner_non_admin():
|
||||
repo = SimpleNamespace(get_user_by_id=AsyncMock())
|
||||
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
|
||||
|
||||
response = client.get("/users/3")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "You don't have permission to view this user"
|
||||
repo.get_user_by_id.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_user_blocks_role_change_for_non_admin():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
|
||||
update_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
|
||||
|
||||
response = client.put("/users/1", json={"role": "ADMIN"})
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Only admins can change user roles"
|
||||
repo.update_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_delete_user_blocks_self_delete_for_admin():
|
||||
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
|
||||
repo = SimpleNamespace(delete_user=AsyncMock())
|
||||
client = _build_client(repo, admin_user=admin_user)
|
||||
|
||||
response = client.delete("/users/1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "You cannot delete your own account"
|
||||
repo.delete_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_activate_user_updates_active_flag():
|
||||
repo = SimpleNamespace(
|
||||
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
|
||||
)
|
||||
|
||||
response = client.post("/users/2/activate")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["is_active"] is True
|
||||
user_update = repo.update_user.await_args.args[1]
|
||||
assert user_update.is_active is True
|
||||
@@ -0,0 +1,36 @@
|
||||
from jose import jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
def test_password_hash_roundtrip():
|
||||
hashed = get_password_hash("secret123")
|
||||
assert hashed != "secret123"
|
||||
assert verify_password("secret123", hashed) is True
|
||||
assert verify_password("wrong", hashed) is False
|
||||
|
||||
|
||||
def test_create_access_token_sets_access_type():
|
||||
token = create_access_token("alice")
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "alice"
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
|
||||
|
||||
def test_create_refresh_token_sets_refresh_type():
|
||||
token = create_refresh_token("alice")
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "alice"
|
||||
assert payload["type"] == "refresh"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
+189
-6
@@ -1,14 +1,197 @@
|
||||
import pytest
|
||||
import sys
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# 自动添加项目根目录到路径(处理项目结构)
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def run_this_test(test_file):
|
||||
"""自定义函数:运行单个测试文件(类似pytest)"""
|
||||
# 提取测试文件名(无扩展名)
|
||||
test_name = os.path.splitext(os.path.basename(test_file))[0]
|
||||
# 使用pytest运行(自动处理导入)
|
||||
pytest.main([test_file, "-v"])
|
||||
|
||||
|
||||
def build_test_app(router, prefix: str = "") -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix=prefix)
|
||||
return app
|
||||
|
||||
|
||||
def load_module_from_path(module_name: str, relative_path: str):
|
||||
module_path = PROJECT_ROOT / relative_path
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
|
||||
module = types.ModuleType(name)
|
||||
if package:
|
||||
module.__path__ = []
|
||||
if attrs:
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
parent_name, _, child_name = name.rpartition(".")
|
||||
if parent_name:
|
||||
parent = sys.modules.get(parent_name)
|
||||
if parent is None:
|
||||
try:
|
||||
parent = importlib.import_module(parent_name)
|
||||
except Exception:
|
||||
parent = types.ModuleType(parent_name)
|
||||
parent.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, parent_name, parent)
|
||||
setattr(parent, child_name, module)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class FakeCursor:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
fetchone_results=None,
|
||||
fetchall_results=None,
|
||||
rowcount: int = 0,
|
||||
rowcounts=None,
|
||||
):
|
||||
self._fetchone_results = list(fetchone_results or [])
|
||||
self._fetchall_results = list(fetchall_results or [])
|
||||
self._rowcounts = list(rowcounts or [])
|
||||
self.rowcount = rowcount
|
||||
self.executed: list[tuple[str, dict | tuple | None]] = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def execute(self, query, params=None):
|
||||
self.executed.append((str(query), params))
|
||||
if self._rowcounts:
|
||||
self.rowcount = self._rowcounts.pop(0)
|
||||
|
||||
async def fetchone(self):
|
||||
if self._fetchone_results:
|
||||
return self._fetchone_results.pop(0)
|
||||
return None
|
||||
|
||||
async def fetchall(self):
|
||||
if self._fetchall_results:
|
||||
return self._fetchall_results.pop(0)
|
||||
return []
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
def __init__(self, cursor: FakeCursor):
|
||||
self.cursor_instance = cursor
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def cursor(self):
|
||||
return self.cursor_instance
|
||||
|
||||
|
||||
class FakeDB:
|
||||
def __init__(self, cursor: FakeCursor):
|
||||
self.connection = FakeConnection(cursor)
|
||||
|
||||
def get_connection(self):
|
||||
return self.connection
|
||||
|
||||
|
||||
class FakeExecuteResult:
|
||||
def __init__(self, *, rows=None, scalar_value=None):
|
||||
self._rows = list(rows or [])
|
||||
self._scalar_value = scalar_value
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return self._rows
|
||||
|
||||
def scalar(self):
|
||||
return self._scalar_value
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, execute_results=None):
|
||||
self._execute_results = list(execute_results or [])
|
||||
self.executed = []
|
||||
self.added = []
|
||||
self.commit_count = 0
|
||||
self.refreshed = []
|
||||
|
||||
def add(self, obj):
|
||||
self.added.append(obj)
|
||||
|
||||
async def execute(self, stmt):
|
||||
self.executed.append(stmt)
|
||||
if self._execute_results:
|
||||
return self._execute_results.pop(0)
|
||||
return FakeExecuteResult()
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def refresh(self, obj):
|
||||
self.refreshed.append(obj)
|
||||
|
||||
|
||||
def make_user(**overrides):
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserInDB
|
||||
|
||||
data = {
|
||||
"id": 1,
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"hashed_password": "hashed-password",
|
||||
"role": UserRole.USER,
|
||||
"is_active": True,
|
||||
"is_superuser": False,
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
data.update(overrides)
|
||||
return UserInDB(**data)
|
||||
|
||||
|
||||
def make_audit_log(**overrides):
|
||||
data = {
|
||||
"id": uuid4(),
|
||||
"user_id": uuid4(),
|
||||
"project_id": uuid4(),
|
||||
"action": "LOGIN",
|
||||
"resource_type": "user",
|
||||
"resource_id": "1",
|
||||
"ip_address": "127.0.0.1",
|
||||
"request_method": "GET",
|
||||
"request_path": "/audit/logs",
|
||||
"request_data": {"ok": True},
|
||||
"response_status": 200,
|
||||
"timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
data.update(overrides)
|
||||
return SimpleNamespace(**data)
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
|
||||
from tests.conftest import FakeAsyncSession, FakeExecuteResult, make_audit_log
|
||||
|
||||
|
||||
def test_create_log_adds_commits_and_refreshes(monkeypatch):
|
||||
class FakeAuditLog:
|
||||
def __init__(self, **kwargs):
|
||||
self.id = uuid4()
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
session = FakeAsyncSession()
|
||||
repo = AuditRepository(session)
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.audit_repository.models.AuditLog",
|
||||
FakeAuditLog,
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.create_log(
|
||||
action="LOGIN",
|
||||
request_method="POST",
|
||||
request_path="/auth/login",
|
||||
response_status=200,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.action == "LOGIN"
|
||||
assert result.request_method == "POST"
|
||||
assert session.commit_count == 1
|
||||
assert len(session.added) == 1
|
||||
assert len(session.refreshed) == 1
|
||||
|
||||
|
||||
def test_get_logs_builds_filtered_query_and_returns_models():
|
||||
log = make_audit_log(action="UPDATE_USER", resource_type="user")
|
||||
session = FakeAsyncSession(
|
||||
execute_results=[FakeExecuteResult(rows=[log])],
|
||||
)
|
||||
repo = AuditRepository(session)
|
||||
user_id = uuid4()
|
||||
project_id = uuid4()
|
||||
start_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
results = asyncio.run(
|
||||
repo.get_logs(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action="UPDATE_USER",
|
||||
resource_type="user",
|
||||
start_time=start_time,
|
||||
skip=5,
|
||||
limit=10,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].action == "UPDATE_USER"
|
||||
stmt = session.executed[0]
|
||||
assert len(stmt._where_criteria) == 5
|
||||
assert stmt._offset == 5
|
||||
assert stmt._limit == 10
|
||||
|
||||
|
||||
def test_get_log_count_returns_zero_when_scalar_none():
|
||||
session = FakeAsyncSession(
|
||||
execute_results=[FakeExecuteResult(scalar_value=None)],
|
||||
)
|
||||
repo = AuditRepository(session)
|
||||
|
||||
result = asyncio.run(repo.get_log_count(action="DELETE_USER"))
|
||||
|
||||
assert result == 0
|
||||
stmt = session.executed[0]
|
||||
assert len(stmt._where_criteria) == 1
|
||||
@@ -0,0 +1,97 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth import dependencies
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from tests.conftest import make_user
|
||||
|
||||
|
||||
def test_get_db_returns_app_state_db():
|
||||
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(db="db-instance")))
|
||||
|
||||
result = asyncio.run(dependencies.get_db(request))
|
||||
|
||||
assert result == "db-instance"
|
||||
|
||||
|
||||
def test_get_db_raises_when_database_missing():
|
||||
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace()))
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(dependencies.get_db(request))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Database not initialized"
|
||||
|
||||
|
||||
def test_get_current_user_accepts_valid_access_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=make_user()))
|
||||
|
||||
result = asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_access_token("tester"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.username == "tester"
|
||||
repo.get_user_by_username.assert_awaited_once_with("tester")
|
||||
|
||||
|
||||
def test_get_current_user_rejects_refresh_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock())
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_refresh_token("tester"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid token type. Access token required."
|
||||
repo.get_user_by_username.assert_not_awaited()
|
||||
|
||||
|
||||
def test_get_current_user_rejects_missing_user():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=None))
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_access_token("ghost"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
def test_get_current_active_user_rejects_inactive_user():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_active_user(
|
||||
current_user=make_user(is_active=False),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Inactive user"
|
||||
|
||||
|
||||
def test_get_current_superuser_rejects_non_superuser():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_superuser(
|
||||
current_user=make_user(is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Not enough privileges. Superuser access required."
|
||||
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth import permissions
|
||||
from app.domain.models.role import UserRole
|
||||
from tests.conftest import make_user
|
||||
|
||||
|
||||
def test_require_role_allows_higher_privilege_user():
|
||||
checker = permissions.require_role(UserRole.OPERATOR)
|
||||
|
||||
result = asyncio.run(checker(current_user=make_user(role=UserRole.ADMIN)))
|
||||
|
||||
assert result.role == UserRole.ADMIN
|
||||
|
||||
|
||||
def test_require_role_rejects_insufficient_role():
|
||||
checker = permissions.require_role(UserRole.ADMIN)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(checker(current_user=make_user(role=UserRole.USER)))
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Required role: ADMIN" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_check_resource_owner_allows_admin():
|
||||
assert permissions.check_resource_owner(
|
||||
99,
|
||||
make_user(id=1, role=UserRole.ADMIN),
|
||||
) is True
|
||||
|
||||
|
||||
def test_check_resource_owner_allows_owner():
|
||||
assert permissions.check_resource_owner(
|
||||
7,
|
||||
make_user(id=7, role=UserRole.USER),
|
||||
) is True
|
||||
|
||||
|
||||
def test_check_resource_owner_rejects_other_user():
|
||||
assert permissions.check_resource_owner(
|
||||
7,
|
||||
make_user(id=8, role=UserRole.USER),
|
||||
) is False
|
||||
|
||||
|
||||
def test_require_owner_or_admin_rejects_other_user():
|
||||
checker = permissions.require_owner_or_admin(7)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(checker(current_user=make_user(id=8, role=UserRole.USER)))
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "You don't have permission to access this resource"
|
||||
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _load_scada_repository():
|
||||
module_path = (
|
||||
@@ -50,18 +49,19 @@ class _FakeConnection:
|
||||
return self.cursor_instance
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
ScadaRepository = _load_scada_repository()
|
||||
conn = _FakeConnection(initial_rowcount=0)
|
||||
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
asyncio.run(
|
||||
ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(conn.cursor_instance.calls) == 2
|
||||
@@ -69,18 +69,19 @@ async def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
assert "INSERT INTO scada.scada_data" in conn.cursor_instance.calls[1][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_scada_field_skips_insert_when_update_succeeds():
|
||||
def test_update_scada_field_skips_insert_when_update_succeeds():
|
||||
ScadaRepository = _load_scada_repository()
|
||||
conn = _FakeConnection(initial_rowcount=1)
|
||||
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
asyncio.run(
|
||||
ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(conn.cursor_instance.calls) == 1
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserCreate, UserUpdate
|
||||
from app.infra.db.metadb.repositories.user_repository import UserRepository
|
||||
from tests.conftest import FakeCursor, FakeDB
|
||||
|
||||
|
||||
def _user_row(**overrides):
|
||||
base = {
|
||||
"id": 1,
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"hashed_password": "hashed-password",
|
||||
"role": "USER",
|
||||
"is_active": True,
|
||||
"is_superuser": False,
|
||||
"created_at": "2025-01-01T00:00:00+00:00",
|
||||
"updated_at": "2025-01-01T00:00:00+00:00",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def test_create_user_hashes_password_and_returns_model(monkeypatch):
|
||||
cursor = FakeCursor(fetchone_results=[_user_row()])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
|
||||
lambda password: f"hashed::{password}",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.create_user(
|
||||
UserCreate(
|
||||
username="tester",
|
||||
email="tester@example.com",
|
||||
password="secret123",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.username == "tester"
|
||||
assert cursor.executed[0][1]["hashed_password"] == "hashed::secret123"
|
||||
|
||||
|
||||
def test_update_user_without_fields_returns_existing_user(monkeypatch):
|
||||
repo = UserRepository(FakeDB(FakeCursor()))
|
||||
existing_user = AsyncMock(return_value="existing")
|
||||
monkeypatch.setattr(repo, "get_user_by_id", existing_user)
|
||||
|
||||
result = asyncio.run(repo.update_user(1, UserUpdate()))
|
||||
|
||||
assert result == "existing"
|
||||
existing_user.assert_awaited_once_with(1)
|
||||
|
||||
|
||||
def test_update_user_builds_dynamic_query(monkeypatch):
|
||||
cursor = FakeCursor(fetchone_results=[_user_row(role="ADMIN", email="new@example.com")])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
|
||||
lambda password: f"hashed::{password}",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.update_user(
|
||||
1,
|
||||
UserUpdate(
|
||||
email="new@example.com",
|
||||
password="new-secret",
|
||||
role=UserRole.ADMIN,
|
||||
is_active=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
query, params = cursor.executed[0]
|
||||
assert "email = %(email)s" in query
|
||||
assert "hashed_password = %(hashed_password)s" in query
|
||||
assert "role = %(role)s" in query
|
||||
assert "is_active = %(is_active)s" in query
|
||||
assert params["hashed_password"] == "hashed::new-secret"
|
||||
assert params["role"] == "ADMIN"
|
||||
assert params["is_active"] is False
|
||||
|
||||
|
||||
def test_delete_user_returns_false_when_execute_raises():
|
||||
cursor = FakeCursor()
|
||||
cursor.execute = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(repo.delete_user(1))
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_user_exists_short_circuits_without_filters():
|
||||
cursor = FakeCursor()
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(repo.user_exists())
|
||||
|
||||
assert result is False
|
||||
assert cursor.executed == []
|
||||
|
||||
|
||||
def test_user_exists_checks_username_or_email():
|
||||
cursor = FakeCursor(fetchone_results=[{"exists": True}])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(
|
||||
repo.user_exists(username="tester", email="tester@example.com")
|
||||
)
|
||||
|
||||
assert result is True
|
||||
query, params = cursor.executed[0]
|
||||
assert "username = %(username)s OR email = %(email)s" in query
|
||||
assert params == {"username": "tester", "email": "tester@example.com"}
|
||||
Reference in New Issue
Block a user