新增 API 测试用例,修复失效接口问题
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import audit as audit_endpoint
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from tests.conftest import build_test_app, make_audit_log
|
||||
|
||||
|
||||
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
|
||||
app = build_test_app(audit_endpoint.router, "/audit")
|
||||
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
|
||||
if metadata_admin is not None:
|
||||
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
|
||||
if metadata_user is not None:
|
||||
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_get_audit_logs_passes_filters():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get(
|
||||
"/audit/logs",
|
||||
params={
|
||||
"action": "LOGIN",
|
||||
"resource_type": "user",
|
||||
"skip": 2,
|
||||
"limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()[0]["action"] == "LOGIN"
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["action"] == "LOGIN"
|
||||
assert kwargs["resource_type"] == "user"
|
||||
assert kwargs["skip"] == 2
|
||||
assert kwargs["limit"] == 5
|
||||
|
||||
|
||||
def test_get_audit_logs_count_returns_count_payload():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(),
|
||||
"get_log_count": AsyncMock(return_value=7),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": 7}
|
||||
repo.get_log_count.assert_awaited_once()
|
||||
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
|
||||
|
||||
|
||||
def test_get_my_audit_logs_forces_current_user_id():
|
||||
current_user = type("User", (), {"id": make_audit_log().user_id})()
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_user=current_user)
|
||||
|
||||
response = client.get("/audit/logs/my", params={"limit": 3})
|
||||
|
||||
assert response.status_code == 200
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["user_id"] == current_user.id
|
||||
assert kwargs["limit"] == 3
|
||||
@@ -0,0 +1,139 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import auth as auth_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.core.security import create_access_token, create_refresh_token, get_password_hash
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, current_user=None) -> TestClient:
|
||||
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_register_success():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[False, False]),
|
||||
create_user=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["username"] == "tester"
|
||||
|
||||
|
||||
def test_register_rejects_duplicate_username():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[True]),
|
||||
create_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Username already registered"
|
||||
repo.create_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_login_supports_email_lookup():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=None),
|
||||
get_user_by_email=AsyncMock(
|
||||
return_value=make_user(
|
||||
email="tester@example.com",
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "tester@example.com", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
|
||||
|
||||
|
||||
def test_login_simple_uses_query_params():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(
|
||||
return_value=make_user(hashed_password=hashed_password)
|
||||
),
|
||||
get_user_by_email=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/simple",
|
||||
params={"username": "tester", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_me_returns_current_user_info():
|
||||
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
|
||||
|
||||
response = client.get("/api/v1/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "alice"
|
||||
|
||||
|
||||
def test_refresh_rejects_access_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock())
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": create_access_token("tester")},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_refresh_success_returns_new_access_token():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
refresh_token = create_refresh_token("tester")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["refresh_token"] == refresh_token
|
||||
assert payload["token_type"] == "bearer"
|
||||
@@ -0,0 +1,152 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _load_project_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.project_info",
|
||||
{},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"list_project": lambda: ["demo"],
|
||||
"have_project": lambda network: network == "demo",
|
||||
"create_project": lambda network: None,
|
||||
"delete_project": lambda network: None,
|
||||
"is_project_open": lambda network: False,
|
||||
"open_project": lambda network: None,
|
||||
"close_project": lambda network: None,
|
||||
"copy_project": lambda source, target: None,
|
||||
"import_inp": lambda network, cs: {"ok": True},
|
||||
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
|
||||
"read_inp": lambda network, inp: True,
|
||||
"dump_inp": lambda network, inp: True,
|
||||
"get_all_vertices": lambda network: [],
|
||||
"get_all_scada_elements": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_extension_data": lambda network, key: None,
|
||||
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.auth.project_dependencies",
|
||||
{"get_metadata_repository": lambda: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.postgresql.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.timescaledb.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_project_endpoints_module",
|
||||
"app/api/v1/endpoints/project.py",
|
||||
)
|
||||
|
||||
|
||||
def test_project_info_returns_404_when_missing(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "missing"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Project missing not found"
|
||||
|
||||
|
||||
def test_project_info_returns_geoserver_payload(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
detail = SimpleNamespace(
|
||||
project_id=uuid4(),
|
||||
name="Demo Project",
|
||||
code="demo",
|
||||
description="desc",
|
||||
gs_workspace="ws",
|
||||
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
status="active",
|
||||
geoserver=SimpleNamespace(
|
||||
gs_base_url="http://gs",
|
||||
gs_admin_user="admin",
|
||||
gs_datastore_name="store",
|
||||
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
srid=4326,
|
||||
),
|
||||
)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["code"] == "demo"
|
||||
assert payload["geoserver"]["gs_base_url"] == "http://gs"
|
||||
|
||||
|
||||
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
called = []
|
||||
|
||||
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
|
||||
|
||||
async def failing_get_pg_db(network):
|
||||
raise RuntimeError("db down")
|
||||
|
||||
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post("/api/v1/openproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "demo"
|
||||
assert called == ["demo"]
|
||||
|
||||
|
||||
def test_project_lock_lifecycle(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
module.lockedPrjs.clear()
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
|
||||
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
|
||||
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
|
||||
|
||||
assert first_lock.json() == 0
|
||||
assert second_lock.json() == 1
|
||||
assert locked_by_me.json() is True
|
||||
assert unlock.json() is True
|
||||
assert locked.json() is False
|
||||
@@ -0,0 +1,154 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def _load_regions_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"Any": Any,
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"add_district_metering_area": _noop,
|
||||
"add_region": _noop,
|
||||
"add_service_area": _noop,
|
||||
"add_virtual_district": _noop,
|
||||
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
|
||||
"calculate_service_area": lambda network: [],
|
||||
"calculate_virtual_district": lambda *args, **kwargs: {},
|
||||
"delete_district_metering_area": _noop,
|
||||
"delete_region": _noop,
|
||||
"delete_service_area": _noop,
|
||||
"delete_virtual_district": _noop,
|
||||
"generate_district_metering_area": _noop,
|
||||
"generate_service_area": _noop,
|
||||
"generate_sub_district_metering_area": _noop,
|
||||
"generate_virtual_district": _noop,
|
||||
"get_all_district_metering_area_ids": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_district_metering_area": lambda network, area_id: {},
|
||||
"get_district_metering_area_schema": lambda network: {},
|
||||
"get_region": lambda network, region_id: {},
|
||||
"get_region_schema": lambda network: {},
|
||||
"get_service_area": lambda network, area_id: {},
|
||||
"get_service_area_schema": lambda network: {},
|
||||
"get_virtual_district": lambda network, area_id: {},
|
||||
"get_virtual_district_schema": lambda network: {},
|
||||
"set_district_metering_area": _noop,
|
||||
"set_region": _noop,
|
||||
"set_service_area": _noop,
|
||||
"set_virtual_district": _noop,
|
||||
},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_regions_endpoints_module",
|
||||
"app/api/v1/endpoints/network/regions.py",
|
||||
)
|
||||
|
||||
|
||||
def test_removed_routes_are_absent_and_return_404(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
openapi = client.get("/openapi.json").json()
|
||||
|
||||
assert "/api/v1/calculateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/getallregions/" not in openapi["paths"]
|
||||
assert "/api/v1/generateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
|
||||
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
|
||||
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
|
||||
|
||||
|
||||
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"calculate_service_area",
|
||||
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
|
||||
)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/calculateservicearea/",
|
||||
params={"network": "demo", "time_index": 5},
|
||||
)
|
||||
schema = client.get("/openapi.json").json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [{"source-1": ["n1", "n2"]}]
|
||||
assert calls == ["demo"]
|
||||
parameter_names = [
|
||||
item["name"]
|
||||
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
|
||||
]
|
||||
assert parameter_names == ["network"]
|
||||
|
||||
|
||||
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_add(network, change_set):
|
||||
captured["network"] = network
|
||||
captured["boundary"] = change_set.operations[0]["boundary"]
|
||||
return {"ok": True}
|
||||
|
||||
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/adddistrictmeteringarea/",
|
||||
params={"network": "demo"},
|
||||
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured == {
|
||||
"network": "demo",
|
||||
"boundary": [(1, 2), (3, 4), (1, 2)],
|
||||
}
|
||||
|
||||
|
||||
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_generate(network, centers, inflate_delta):
|
||||
captured["args"] = (network, centers, inflate_delta)
|
||||
return {"generated": True}
|
||||
|
||||
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/generatevirtualdistrict/",
|
||||
params={"network": "demo", "inflate_delta": 0.75},
|
||||
json={"centers": ["J1", "J2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
|
||||
@@ -0,0 +1,175 @@
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
def _load_simulation_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation",
|
||||
{"run_simulation": lambda **kwargs: None},
|
||||
)
|
||||
install_stub(monkeypatch, "app.services.globals", {})
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"run_project": lambda network: "report",
|
||||
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
|
||||
"run_inp": lambda network: "inp-report",
|
||||
"dump_output": lambda output: f"dump::{output}",
|
||||
},
|
||||
)
|
||||
install_stub(monkeypatch, "app.algorithms", package=True)
|
||||
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.simulation.scenarios",
|
||||
{
|
||||
"burst_analysis": lambda *args, **kwargs: "burst",
|
||||
"valve_close_analysis": lambda *args, **kwargs: "valve",
|
||||
"flushing_analysis": lambda *args, **kwargs: "flush",
|
||||
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
|
||||
"age_analysis": lambda *args, **kwargs: "age",
|
||||
"pressure_regulation": lambda *args, **kwargs: "pressure",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.sensor",
|
||||
{
|
||||
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
|
||||
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.network_import",
|
||||
{"network_update": lambda *args, **kwargs: "updated"},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation_ops",
|
||||
{
|
||||
"project_management": lambda *args, **kwargs: "managed",
|
||||
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
|
||||
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.valve_isolation",
|
||||
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_simulation_endpoints_module",
|
||||
"app/api/v1/endpoints/simulation.py",
|
||||
)
|
||||
|
||||
|
||||
def test_run_project_endpoint_returns_plain_text(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get("/api/v1/runproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "report::demo"
|
||||
|
||||
|
||||
def test_scheduling_analysis_maps_request_body(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
|
||||
captured["args"] = (
|
||||
network,
|
||||
start_time,
|
||||
pump_control,
|
||||
tank_id,
|
||||
water_plant_output_id,
|
||||
time_delta,
|
||||
)
|
||||
return "scheduled"
|
||||
|
||||
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/scheduling_analysis/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1, 0, 1]},
|
||||
"tank_id": "T1",
|
||||
"water_plant_output_id": "R1",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "scheduled"
|
||||
assert captured["args"] == (
|
||||
"demo",
|
||||
"2025-01-01T08:00:00+08:00",
|
||||
{"P1": [1, 0, 1]},
|
||||
"T1",
|
||||
"R1",
|
||||
300,
|
||||
)
|
||||
|
||||
|
||||
def test_project_management_maps_named_arguments(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_project_management(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return "managed"
|
||||
|
||||
monkeypatch.setattr(module, "project_management", fake_project_management)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/project_management/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_init_level": {"T1": 10.0},
|
||||
"region_demand": {"R1": 20.0},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "managed"
|
||||
assert captured == {
|
||||
"prj_name": "demo",
|
||||
"start_datetime": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_initial_level_control": {"T1": 10.0},
|
||||
"region_demand_control": {"R1": 20.0},
|
||||
}
|
||||
|
||||
|
||||
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
def boom(_path):
|
||||
raise RuntimeError("write failed")
|
||||
|
||||
monkeypatch.setattr(module, "network_update", boom)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/network_update/",
|
||||
files={"file": ("update.txt", b"payload")},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "数据库操作失败: write failed" in response.json()["detail"]
|
||||
assert list(Path(tmp_path).glob("network_update_*"))
|
||||
@@ -0,0 +1,95 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import user_management as user_management_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.auth.permissions import get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
|
||||
app = build_test_app(user_management_endpoint.router, "/users")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
if admin_user is not None:
|
||||
app.dependency_overrides[get_current_admin] = lambda: admin_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_list_users_requires_admin_role():
|
||||
repo = SimpleNamespace(
|
||||
get_all_users=AsyncMock(
|
||||
return_value=[
|
||||
make_user(id=1, username="admin", role=UserRole.ADMIN),
|
||||
make_user(id=2, username="user2"),
|
||||
]
|
||||
)
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
current_user=make_user(id=1, role=UserRole.ADMIN),
|
||||
)
|
||||
|
||||
response = client.get("/users/", params={"skip": 5, "limit": 2})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
|
||||
|
||||
|
||||
def test_get_user_rejects_non_owner_non_admin():
|
||||
repo = SimpleNamespace(get_user_by_id=AsyncMock())
|
||||
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
|
||||
|
||||
response = client.get("/users/3")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "You don't have permission to view this user"
|
||||
repo.get_user_by_id.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_user_blocks_role_change_for_non_admin():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
|
||||
update_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
|
||||
|
||||
response = client.put("/users/1", json={"role": "ADMIN"})
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Only admins can change user roles"
|
||||
repo.update_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_delete_user_blocks_self_delete_for_admin():
|
||||
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
|
||||
repo = SimpleNamespace(delete_user=AsyncMock())
|
||||
client = _build_client(repo, admin_user=admin_user)
|
||||
|
||||
response = client.delete("/users/1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "You cannot delete your own account"
|
||||
repo.delete_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_activate_user_updates_active_flag():
|
||||
repo = SimpleNamespace(
|
||||
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
|
||||
)
|
||||
|
||||
response = client.post("/users/2/activate")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["is_active"] is True
|
||||
user_update = repo.update_user.await_args.args[1]
|
||||
assert user_update.is_active is True
|
||||
Reference in New Issue
Block a user