新增 API 测试用例,修复失效接口问题

This commit is contained in:
2026-05-21 15:32:12 +08:00
parent 751950e5b5
commit 2317f4d527
15 changed files with 1486 additions and 96 deletions
+91
View File
@@ -0,0 +1,91 @@
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import audit as audit_endpoint
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from tests.conftest import build_test_app, make_audit_log
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
app = build_test_app(audit_endpoint.router, "/audit")
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
if metadata_admin is not None:
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
if metadata_user is not None:
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
return TestClient(app)
def test_get_audit_logs_passes_filters():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get(
"/audit/logs",
params={
"action": "LOGIN",
"resource_type": "user",
"skip": 2,
"limit": 5,
},
)
assert response.status_code == 200
assert response.json()[0]["action"] == "LOGIN"
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["action"] == "LOGIN"
assert kwargs["resource_type"] == "user"
assert kwargs["skip"] == 2
assert kwargs["limit"] == 5
def test_get_audit_logs_count_returns_count_payload():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(),
"get_log_count": AsyncMock(return_value=7),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
assert response.status_code == 200
assert response.json() == {"count": 7}
repo.get_log_count.assert_awaited_once()
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
def test_get_my_audit_logs_forces_current_user_id():
current_user = type("User", (), {"id": make_audit_log().user_id})()
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_user=current_user)
response = client.get("/audit/logs/my", params={"limit": 3})
assert response.status_code == 200
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["user_id"] == current_user.id
assert kwargs["limit"] == 3
+139
View File
@@ -0,0 +1,139 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import auth as auth_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.core.security import create_access_token, create_refresh_token, get_password_hash
from tests.conftest import build_test_app, make_user
def _build_client(repo, current_user=None) -> TestClient:
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
return TestClient(app)
def test_register_success():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[False, False]),
create_user=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 201
assert response.json()["username"] == "tester"
def test_register_rejects_duplicate_username():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[True]),
create_user=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 400
assert response.json()["detail"] == "Username already registered"
repo.create_user.assert_not_awaited()
def test_login_supports_email_lookup():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=None),
get_user_by_email=AsyncMock(
return_value=make_user(
email="tester@example.com",
hashed_password=hashed_password,
)
),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login",
data={"username": "tester@example.com", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
def test_login_simple_uses_query_params():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(
return_value=make_user(hashed_password=hashed_password)
),
get_user_by_email=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login/simple",
params={"username": "tester", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
def test_me_returns_current_user_info():
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
response = client.get("/api/v1/auth/me")
assert response.status_code == 200
assert response.json()["username"] == "alice"
def test_refresh_rejects_access_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock())
client = _build_client(repo)
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": create_access_token("tester")},
)
assert response.status_code == 401
def test_refresh_success_returns_new_access_token():
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
refresh_token = create_refresh_token("tester")
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": refresh_token},
)
assert response.status_code == 200
payload = response.json()
assert payload["refresh_token"] == refresh_token
assert payload["token_type"] == "bearer"
+152
View File
@@ -0,0 +1,152 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import uuid4
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _load_project_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.project_info",
{},
)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"ChangeSet": DummyChangeSet,
"list_project": lambda: ["demo"],
"have_project": lambda network: network == "demo",
"create_project": lambda network: None,
"delete_project": lambda network: None,
"is_project_open": lambda network: False,
"open_project": lambda network: None,
"close_project": lambda network: None,
"copy_project": lambda source, target: None,
"import_inp": lambda network, cs: {"ok": True},
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
"read_inp": lambda network, inp: True,
"dump_inp": lambda network, inp: True,
"get_all_vertices": lambda network: [],
"get_all_scada_elements": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_extension_data": lambda network, key: None,
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
},
)
install_stub(
monkeypatch,
"app.auth.project_dependencies",
{"get_metadata_repository": lambda: None},
)
install_stub(
monkeypatch,
"app.infra.db.postgresql.database",
{"get_database_instance": lambda network: None},
)
install_stub(
monkeypatch,
"app.infra.db.timescaledb.database",
{"get_database_instance": lambda network: None},
)
return load_module_from_path(
"tests_project_endpoints_module",
"app/api/v1/endpoints/project.py",
)
def test_project_info_returns_404_when_missing(monkeypatch):
module = _load_project_module(monkeypatch)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "missing"})
assert response.status_code == 404
assert response.json()["detail"] == "Project missing not found"
def test_project_info_returns_geoserver_payload(monkeypatch):
module = _load_project_module(monkeypatch)
detail = SimpleNamespace(
project_id=uuid4(),
name="Demo Project",
code="demo",
description="desc",
gs_workspace="ws",
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
status="active",
geoserver=SimpleNamespace(
gs_base_url="http://gs",
gs_admin_user="admin",
gs_datastore_name="store",
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
srid=4326,
),
)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "demo"})
assert response.status_code == 200
payload = response.json()
assert payload["code"] == "demo"
assert payload["geoserver"]["gs_base_url"] == "http://gs"
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
module = _load_project_module(monkeypatch)
called = []
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
async def failing_get_pg_db(network):
raise RuntimeError("db down")
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post("/api/v1/openproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.json() == "demo"
assert called == ["demo"]
def test_project_lock_lifecycle(monkeypatch):
module = _load_project_module(monkeypatch)
module.lockedPrjs.clear()
client = TestClient(build_test_app(module.router, "/api/v1"))
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
assert first_lock.json() == 0
assert second_lock.json() == 1
assert locked_by_me.json() is True
assert unlock.json() is True
assert locked.json() is False
+154
View File
@@ -0,0 +1,154 @@
from typing import Any
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _noop(*args, **kwargs):
return None
def _load_regions_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"Any": Any,
"ChangeSet": DummyChangeSet,
"add_district_metering_area": _noop,
"add_region": _noop,
"add_service_area": _noop,
"add_virtual_district": _noop,
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
"calculate_service_area": lambda network: [],
"calculate_virtual_district": lambda *args, **kwargs: {},
"delete_district_metering_area": _noop,
"delete_region": _noop,
"delete_service_area": _noop,
"delete_virtual_district": _noop,
"generate_district_metering_area": _noop,
"generate_service_area": _noop,
"generate_sub_district_metering_area": _noop,
"generate_virtual_district": _noop,
"get_all_district_metering_area_ids": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_district_metering_area": lambda network, area_id: {},
"get_district_metering_area_schema": lambda network: {},
"get_region": lambda network, region_id: {},
"get_region_schema": lambda network: {},
"get_service_area": lambda network, area_id: {},
"get_service_area_schema": lambda network: {},
"get_virtual_district": lambda network, area_id: {},
"get_virtual_district_schema": lambda network: {},
"set_district_metering_area": _noop,
"set_region": _noop,
"set_service_area": _noop,
"set_virtual_district": _noop,
},
)
return load_module_from_path(
"tests_regions_endpoints_module",
"app/api/v1/endpoints/network/regions.py",
)
def test_removed_routes_are_absent_and_return_404(monkeypatch):
module = _load_regions_module(monkeypatch)
client = TestClient(build_test_app(module.router, "/api/v1"))
openapi = client.get("/openapi.json").json()
assert "/api/v1/calculateregion/" not in openapi["paths"]
assert "/api/v1/getallregions/" not in openapi["paths"]
assert "/api/v1/generateregion/" not in openapi["paths"]
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
module = _load_regions_module(monkeypatch)
calls = []
monkeypatch.setattr(
module,
"calculate_service_area",
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get(
"/api/v1/calculateservicearea/",
params={"network": "demo", "time_index": 5},
)
schema = client.get("/openapi.json").json()
assert response.status_code == 200
assert response.json() == [{"source-1": ["n1", "n2"]}]
assert calls == ["demo"]
parameter_names = [
item["name"]
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
]
assert parameter_names == ["network"]
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_add(network, change_set):
captured["network"] = network
captured["boundary"] = change_set.operations[0]["boundary"]
return {"ok": True}
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/adddistrictmeteringarea/",
params={"network": "demo"},
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
)
assert response.status_code == 200
assert captured == {
"network": "demo",
"boundary": [(1, 2), (3, 4), (1, 2)],
}
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_generate(network, centers, inflate_delta):
captured["args"] = (network, centers, inflate_delta)
return {"generated": True}
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/generatevirtualdistrict/",
params={"network": "demo", "inflate_delta": 0.75},
json={"centers": ["J1", "J2"]},
)
assert response.status_code == 200
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
+175
View File
@@ -0,0 +1,175 @@
from pathlib import Path
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
def _load_simulation_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.simulation",
{"run_simulation": lambda **kwargs: None},
)
install_stub(monkeypatch, "app.services.globals", {})
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"run_project": lambda network: "report",
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
"run_inp": lambda network: "inp-report",
"dump_output": lambda output: f"dump::{output}",
},
)
install_stub(monkeypatch, "app.algorithms", package=True)
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
install_stub(
monkeypatch,
"app.algorithms.simulation.scenarios",
{
"burst_analysis": lambda *args, **kwargs: "burst",
"valve_close_analysis": lambda *args, **kwargs: "valve",
"flushing_analysis": lambda *args, **kwargs: "flush",
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
"age_analysis": lambda *args, **kwargs: "age",
"pressure_regulation": lambda *args, **kwargs: "pressure",
},
)
install_stub(
monkeypatch,
"app.algorithms.sensor",
{
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
},
)
install_stub(
monkeypatch,
"app.services.network_import",
{"network_update": lambda *args, **kwargs: "updated"},
)
install_stub(
monkeypatch,
"app.services.simulation_ops",
{
"project_management": lambda *args, **kwargs: "managed",
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
},
)
install_stub(
monkeypatch,
"app.services.valve_isolation",
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
)
return load_module_from_path(
"tests_simulation_endpoints_module",
"app/api/v1/endpoints/simulation.py",
)
def test_run_project_endpoint_returns_plain_text(monkeypatch):
module = _load_simulation_module(monkeypatch)
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get("/api/v1/runproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.text == "report::demo"
def test_scheduling_analysis_maps_request_body(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
captured["args"] = (
network,
start_time,
pump_control,
tank_id,
water_plant_output_id,
time_delta,
)
return "scheduled"
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/scheduling_analysis/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1, 0, 1]},
"tank_id": "T1",
"water_plant_output_id": "R1",
},
)
assert response.status_code == 200
assert response.json() == "scheduled"
assert captured["args"] == (
"demo",
"2025-01-01T08:00:00+08:00",
{"P1": [1, 0, 1]},
"T1",
"R1",
300,
)
def test_project_management_maps_named_arguments(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_project_management(**kwargs):
captured.update(kwargs)
return "managed"
monkeypatch.setattr(module, "project_management", fake_project_management)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/project_management/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_init_level": {"T1": 10.0},
"region_demand": {"R1": 20.0},
},
)
assert response.status_code == 200
assert response.json() == "managed"
assert captured == {
"prj_name": "demo",
"start_datetime": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_initial_level_control": {"T1": 10.0},
"region_demand_control": {"R1": 20.0},
}
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
module = _load_simulation_module(monkeypatch)
monkeypatch.chdir(tmp_path)
def boom(_path):
raise RuntimeError("write failed")
monkeypatch.setattr(module, "network_update", boom)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/network_update/",
files={"file": ("update.txt", b"payload")},
)
assert response.status_code == 500
assert "数据库操作失败: write failed" in response.json()["detail"]
assert list(Path(tmp_path).glob("network_update_*"))
@@ -0,0 +1,95 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import user_management as user_management_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.auth.permissions import get_current_admin
from app.domain.models.role import UserRole
from tests.conftest import build_test_app, make_user
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
app = build_test_app(user_management_endpoint.router, "/users")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
if admin_user is not None:
app.dependency_overrides[get_current_admin] = lambda: admin_user
return TestClient(app)
def test_list_users_requires_admin_role():
repo = SimpleNamespace(
get_all_users=AsyncMock(
return_value=[
make_user(id=1, username="admin", role=UserRole.ADMIN),
make_user(id=2, username="user2"),
]
)
)
client = _build_client(
repo,
current_user=make_user(id=1, role=UserRole.ADMIN),
)
response = client.get("/users/", params={"skip": 5, "limit": 2})
assert response.status_code == 200
assert len(response.json()) == 2
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
def test_get_user_rejects_non_owner_non_admin():
repo = SimpleNamespace(get_user_by_id=AsyncMock())
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
response = client.get("/users/3")
assert response.status_code == 403
assert response.json()["detail"] == "You don't have permission to view this user"
repo.get_user_by_id.assert_not_awaited()
def test_update_user_blocks_role_change_for_non_admin():
repo = SimpleNamespace(
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
update_user=AsyncMock(),
)
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
response = client.put("/users/1", json={"role": "ADMIN"})
assert response.status_code == 403
assert response.json()["detail"] == "Only admins can change user roles"
repo.update_user.assert_not_awaited()
def test_delete_user_blocks_self_delete_for_admin():
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
repo = SimpleNamespace(delete_user=AsyncMock())
client = _build_client(repo, admin_user=admin_user)
response = client.delete("/users/1")
assert response.status_code == 400
assert response.json()["detail"] == "You cannot delete your own account"
repo.delete_user.assert_not_awaited()
def test_activate_user_updates_active_flag():
repo = SimpleNamespace(
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
)
client = _build_client(
repo,
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
)
response = client.post("/users/2/activate")
assert response.status_code == 200
assert response.json()["is_active"] is True
user_update = repo.update_user.await_args.args[1]
assert user_update.is_active is True