新增 API 测试用例,修复失效接口问题

This commit is contained in:
2026-05-21 15:32:12 +08:00
parent 751950e5b5
commit 2317f4d527
15 changed files with 1486 additions and 96 deletions
+4 -72
View File
@@ -7,11 +7,9 @@ from app.services.tjnetwork import (
add_region,
add_service_area,
add_virtual_district,
# calculate_district_metering_area,
calculate_district_metering_area_for_network,
calculate_district_metering_area_for_nodes,
calculate_district_metering_area_for_region,
# calculate_region,
calculate_service_area,
calculate_virtual_district,
delete_district_metering_area,
@@ -19,13 +17,11 @@ from app.services.tjnetwork import (
delete_service_area,
delete_virtual_district,
generate_district_metering_area,
# generate_region,
generate_service_area,
generate_sub_district_metering_area,
generate_virtual_district,
get_all_district_metering_area_ids,
get_all_district_metering_areas,
# get_all_regions,
get_all_service_areas,
get_all_virtual_districts,
get_district_metering_area,
@@ -48,18 +44,6 @@ router = APIRouter()
# region 32
############################################################
@router.get(
"/calculateregion/",
summary="计算区域",
description="计算指定水网在指定时间步长的区域分区"
)
async def fastapi_calculate_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
time_index: int = Query(..., description="时间步长索引", ge=0)
) -> dict[str, Any]:
"""计算区域分区。"""
return calculate_region(network, time_index)
@router.get(
"/getregionschema/",
summary="获取区域属性架构",
@@ -125,62 +109,11 @@ async def fastapi_delete_region(
props = await req.json()
return delete_region(network, ChangeSet(props))
@router.get(
"/getallregions/",
summary="获取所有区域",
description="获取指定水网中的所有区域信息"
)
async def fastapi_get_all_regions(
network: str = Query(..., description="管网名称(或数据库名称)")
) -> list[dict[str, Any]]:
"""获取所有区域的信息列表。"""
return get_all_regions(network)
@router.post(
"/generateregion/",
response_model=None,
summary="生成区域分区",
description="根据参数自动生成水网的区域分区"
)
async def fastapi_generate_region(
network: str = Query(..., description="管网名称(或数据库名称)"),
inflate_delta: float = Query(..., description="膨胀参数")
) -> ChangeSet:
"""生成区域分区。"""
return generate_region(network, inflate_delta)
############################################################
# district_metering_area 33
############################################################
@router.get(
"/calculatedistrictmeteringarea/",
summary="计算DMA分区",
description="计算指定节点集的区域计量(DMA)分区方案"
)
async def fastapi_calculate_district_metering_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
req: Request = None
) -> list[list[str]]:
"""
计算DMA分区。
请求体格式:
{
"nodes": 节点ID列表(list[str]),
"part_count": 分区数量(int),
"part_type": 分区类型(int)
}
"""
props = await req.json()
nodes = props["nodes"]
part_count = props["part_count"]
part_type = props["part_type"]
return calculate_district_metering_area(
network, nodes, part_count, part_type
)
@router.get(
"/calculatedistrictmeteringareaforregion/",
summary="计算区域内DMA分区",
@@ -368,14 +301,13 @@ async def fastapi_generate_sub_district_metering_area(
@router.get(
"/calculateservicearea/",
summary="计算服务区",
description="计算指定水网在指定时间步长的服务区分区"
description="计算指定水网的服务区分区,返回全部时间步结果"
)
async def fastapi_calculate_service_area(
network: str = Query(..., description="管网名称(或数据库名称)"),
time_index: int = Query(..., description="时间步长索引", ge=0)
) -> dict[str, Any]:
"""计算服务区分区。"""
return calculate_service_area(network, time_index)
) -> list[dict[str, list[str]]]:
"""计算服务区分区,返回全部时间步结果。"""
return calculate_service_area(network)
@router.get(
"/getserviceareaschema/",
+91
View File
@@ -0,0 +1,91 @@
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import audit as audit_endpoint
from app.auth.metadata_dependencies import (
get_current_metadata_admin,
get_current_metadata_user,
)
from tests.conftest import build_test_app, make_audit_log
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
app = build_test_app(audit_endpoint.router, "/audit")
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
if metadata_admin is not None:
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
if metadata_user is not None:
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
return TestClient(app)
def test_get_audit_logs_passes_filters():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get(
"/audit/logs",
params={
"action": "LOGIN",
"resource_type": "user",
"skip": 2,
"limit": 5,
},
)
assert response.status_code == 200
assert response.json()[0]["action"] == "LOGIN"
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["action"] == "LOGIN"
assert kwargs["resource_type"] == "user"
assert kwargs["skip"] == 2
assert kwargs["limit"] == 5
def test_get_audit_logs_count_returns_count_payload():
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(),
"get_log_count": AsyncMock(return_value=7),
},
)()
client = _build_client(repo, metadata_admin=object())
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
assert response.status_code == 200
assert response.json() == {"count": 7}
repo.get_log_count.assert_awaited_once()
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
def test_get_my_audit_logs_forces_current_user_id():
current_user = type("User", (), {"id": make_audit_log().user_id})()
repo = type(
"Repo",
(),
{
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
"get_log_count": AsyncMock(),
},
)()
client = _build_client(repo, metadata_user=current_user)
response = client.get("/audit/logs/my", params={"limit": 3})
assert response.status_code == 200
repo.get_logs.assert_awaited_once()
kwargs = repo.get_logs.await_args.kwargs
assert kwargs["user_id"] == current_user.id
assert kwargs["limit"] == 3
+139
View File
@@ -0,0 +1,139 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import auth as auth_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.core.security import create_access_token, create_refresh_token, get_password_hash
from tests.conftest import build_test_app, make_user
def _build_client(repo, current_user=None) -> TestClient:
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
return TestClient(app)
def test_register_success():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[False, False]),
create_user=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 201
assert response.json()["username"] == "tester"
def test_register_rejects_duplicate_username():
repo = SimpleNamespace(
user_exists=AsyncMock(side_effect=[True]),
create_user=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/register",
json={
"username": "tester",
"email": "tester@example.com",
"password": "secret123",
},
)
assert response.status_code == 400
assert response.json()["detail"] == "Username already registered"
repo.create_user.assert_not_awaited()
def test_login_supports_email_lookup():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=None),
get_user_by_email=AsyncMock(
return_value=make_user(
email="tester@example.com",
hashed_password=hashed_password,
)
),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login",
data={"username": "tester@example.com", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
def test_login_simple_uses_query_params():
hashed_password = get_password_hash("secret123")
repo = SimpleNamespace(
get_user_by_username=AsyncMock(
return_value=make_user(hashed_password=hashed_password)
),
get_user_by_email=AsyncMock(),
)
client = _build_client(repo)
response = client.post(
"/api/v1/auth/login/simple",
params={"username": "tester", "password": "secret123"},
)
assert response.status_code == 200
assert response.json()["token_type"] == "bearer"
def test_me_returns_current_user_info():
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
response = client.get("/api/v1/auth/me")
assert response.status_code == 200
assert response.json()["username"] == "alice"
def test_refresh_rejects_access_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock())
client = _build_client(repo)
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": create_access_token("tester")},
)
assert response.status_code == 401
def test_refresh_success_returns_new_access_token():
repo = SimpleNamespace(
get_user_by_username=AsyncMock(return_value=make_user()),
)
client = _build_client(repo)
refresh_token = create_refresh_token("tester")
response = client.post(
"/api/v1/auth/refresh",
params={"refresh_token": refresh_token},
)
assert response.status_code == 200
payload = response.json()
assert payload["refresh_token"] == refresh_token
assert payload["token_type"] == "bearer"
+152
View File
@@ -0,0 +1,152 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import uuid4
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _load_project_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.project_info",
{},
)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"ChangeSet": DummyChangeSet,
"list_project": lambda: ["demo"],
"have_project": lambda network: network == "demo",
"create_project": lambda network: None,
"delete_project": lambda network: None,
"is_project_open": lambda network: False,
"open_project": lambda network: None,
"close_project": lambda network: None,
"copy_project": lambda source, target: None,
"import_inp": lambda network, cs: {"ok": True},
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
"read_inp": lambda network, inp: True,
"dump_inp": lambda network, inp: True,
"get_all_vertices": lambda network: [],
"get_all_scada_elements": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_extension_data": lambda network, key: None,
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
},
)
install_stub(
monkeypatch,
"app.auth.project_dependencies",
{"get_metadata_repository": lambda: None},
)
install_stub(
monkeypatch,
"app.infra.db.postgresql.database",
{"get_database_instance": lambda network: None},
)
install_stub(
monkeypatch,
"app.infra.db.timescaledb.database",
{"get_database_instance": lambda network: None},
)
return load_module_from_path(
"tests_project_endpoints_module",
"app/api/v1/endpoints/project.py",
)
def test_project_info_returns_404_when_missing(monkeypatch):
module = _load_project_module(monkeypatch)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "missing"})
assert response.status_code == 404
assert response.json()["detail"] == "Project missing not found"
def test_project_info_returns_geoserver_payload(monkeypatch):
module = _load_project_module(monkeypatch)
detail = SimpleNamespace(
project_id=uuid4(),
name="Demo Project",
code="demo",
description="desc",
gs_workspace="ws",
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
status="active",
geoserver=SimpleNamespace(
gs_base_url="http://gs",
gs_admin_user="admin",
gs_datastore_name="store",
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
srid=4326,
),
)
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
app = build_test_app(module.router, "/api/v1")
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
client = TestClient(app)
response = client.get("/api/v1/project_info/", params={"network": "demo"})
assert response.status_code == 200
payload = response.json()
assert payload["code"] == "demo"
assert payload["geoserver"]["gs_base_url"] == "http://gs"
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
module = _load_project_module(monkeypatch)
called = []
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
async def failing_get_pg_db(network):
raise RuntimeError("db down")
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post("/api/v1/openproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.json() == "demo"
assert called == ["demo"]
def test_project_lock_lifecycle(monkeypatch):
module = _load_project_module(monkeypatch)
module.lockedPrjs.clear()
client = TestClient(build_test_app(module.router, "/api/v1"))
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
assert first_lock.json() == 0
assert second_lock.json() == 1
assert locked_by_me.json() is True
assert unlock.json() is True
assert locked.json() is False
+154
View File
@@ -0,0 +1,154 @@
from typing import Any
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
class DummyChangeSet:
def __init__(self, operations=None):
if operations is None:
self.operations = []
elif isinstance(operations, dict):
self.operations = [operations]
else:
self.operations = operations
def _noop(*args, **kwargs):
return None
def _load_regions_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"Any": Any,
"ChangeSet": DummyChangeSet,
"add_district_metering_area": _noop,
"add_region": _noop,
"add_service_area": _noop,
"add_virtual_district": _noop,
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
"calculate_service_area": lambda network: [],
"calculate_virtual_district": lambda *args, **kwargs: {},
"delete_district_metering_area": _noop,
"delete_region": _noop,
"delete_service_area": _noop,
"delete_virtual_district": _noop,
"generate_district_metering_area": _noop,
"generate_service_area": _noop,
"generate_sub_district_metering_area": _noop,
"generate_virtual_district": _noop,
"get_all_district_metering_area_ids": lambda network: [],
"get_all_district_metering_areas": lambda network: [],
"get_all_service_areas": lambda network: [],
"get_all_virtual_districts": lambda network: [],
"get_district_metering_area": lambda network, area_id: {},
"get_district_metering_area_schema": lambda network: {},
"get_region": lambda network, region_id: {},
"get_region_schema": lambda network: {},
"get_service_area": lambda network, area_id: {},
"get_service_area_schema": lambda network: {},
"get_virtual_district": lambda network, area_id: {},
"get_virtual_district_schema": lambda network: {},
"set_district_metering_area": _noop,
"set_region": _noop,
"set_service_area": _noop,
"set_virtual_district": _noop,
},
)
return load_module_from_path(
"tests_regions_endpoints_module",
"app/api/v1/endpoints/network/regions.py",
)
def test_removed_routes_are_absent_and_return_404(monkeypatch):
module = _load_regions_module(monkeypatch)
client = TestClient(build_test_app(module.router, "/api/v1"))
openapi = client.get("/openapi.json").json()
assert "/api/v1/calculateregion/" not in openapi["paths"]
assert "/api/v1/getallregions/" not in openapi["paths"]
assert "/api/v1/generateregion/" not in openapi["paths"]
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
module = _load_regions_module(monkeypatch)
calls = []
monkeypatch.setattr(
module,
"calculate_service_area",
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get(
"/api/v1/calculateservicearea/",
params={"network": "demo", "time_index": 5},
)
schema = client.get("/openapi.json").json()
assert response.status_code == 200
assert response.json() == [{"source-1": ["n1", "n2"]}]
assert calls == ["demo"]
parameter_names = [
item["name"]
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
]
assert parameter_names == ["network"]
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_add(network, change_set):
captured["network"] = network
captured["boundary"] = change_set.operations[0]["boundary"]
return {"ok": True}
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/adddistrictmeteringarea/",
params={"network": "demo"},
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
)
assert response.status_code == 200
assert captured == {
"network": "demo",
"boundary": [(1, 2), (3, 4), (1, 2)],
}
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
module = _load_regions_module(monkeypatch)
captured = {}
def fake_generate(network, centers, inflate_delta):
captured["args"] = (network, centers, inflate_delta)
return {"generated": True}
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/generatevirtualdistrict/",
params={"network": "demo", "inflate_delta": 0.75},
json={"centers": ["J1", "J2"]},
)
assert response.status_code == 200
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
+175
View File
@@ -0,0 +1,175 @@
from pathlib import Path
from fastapi.testclient import TestClient
from tests.conftest import build_test_app, install_stub, load_module_from_path
def _load_simulation_module(monkeypatch):
install_stub(monkeypatch, "app.services", package=True)
install_stub(
monkeypatch,
"app.services.simulation",
{"run_simulation": lambda **kwargs: None},
)
install_stub(monkeypatch, "app.services.globals", {})
install_stub(
monkeypatch,
"app.services.tjnetwork",
{
"run_project": lambda network: "report",
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
"run_inp": lambda network: "inp-report",
"dump_output": lambda output: f"dump::{output}",
},
)
install_stub(monkeypatch, "app.algorithms", package=True)
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
install_stub(
monkeypatch,
"app.algorithms.simulation.scenarios",
{
"burst_analysis": lambda *args, **kwargs: "burst",
"valve_close_analysis": lambda *args, **kwargs: "valve",
"flushing_analysis": lambda *args, **kwargs: "flush",
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
"age_analysis": lambda *args, **kwargs: "age",
"pressure_regulation": lambda *args, **kwargs: "pressure",
},
)
install_stub(
monkeypatch,
"app.algorithms.sensor",
{
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
},
)
install_stub(
monkeypatch,
"app.services.network_import",
{"network_update": lambda *args, **kwargs: "updated"},
)
install_stub(
monkeypatch,
"app.services.simulation_ops",
{
"project_management": lambda *args, **kwargs: "managed",
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
},
)
install_stub(
monkeypatch,
"app.services.valve_isolation",
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
)
return load_module_from_path(
"tests_simulation_endpoints_module",
"app/api/v1/endpoints/simulation.py",
)
def test_run_project_endpoint_returns_plain_text(monkeypatch):
module = _load_simulation_module(monkeypatch)
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.get("/api/v1/runproject/", params={"network": "demo"})
assert response.status_code == 200
assert response.text == "report::demo"
def test_scheduling_analysis_maps_request_body(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
captured["args"] = (
network,
start_time,
pump_control,
tank_id,
water_plant_output_id,
time_delta,
)
return "scheduled"
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/scheduling_analysis/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1, 0, 1]},
"tank_id": "T1",
"water_plant_output_id": "R1",
},
)
assert response.status_code == 200
assert response.json() == "scheduled"
assert captured["args"] == (
"demo",
"2025-01-01T08:00:00+08:00",
{"P1": [1, 0, 1]},
"T1",
"R1",
300,
)
def test_project_management_maps_named_arguments(monkeypatch):
module = _load_simulation_module(monkeypatch)
captured = {}
def fake_project_management(**kwargs):
captured.update(kwargs)
return "managed"
monkeypatch.setattr(module, "project_management", fake_project_management)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/project_management/",
json={
"network": "demo",
"start_time": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_init_level": {"T1": 10.0},
"region_demand": {"R1": 20.0},
},
)
assert response.status_code == 200
assert response.json() == "managed"
assert captured == {
"prj_name": "demo",
"start_datetime": "2025-01-01T08:00:00+08:00",
"pump_control": {"P1": [1]},
"tank_initial_level_control": {"T1": 10.0},
"region_demand_control": {"R1": 20.0},
}
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
module = _load_simulation_module(monkeypatch)
monkeypatch.chdir(tmp_path)
def boom(_path):
raise RuntimeError("write failed")
monkeypatch.setattr(module, "network_update", boom)
client = TestClient(build_test_app(module.router, "/api/v1"))
response = client.post(
"/api/v1/network_update/",
files={"file": ("update.txt", b"payload")},
)
assert response.status_code == 500
assert "数据库操作失败: write failed" in response.json()["detail"]
assert list(Path(tmp_path).glob("network_update_*"))
@@ -0,0 +1,95 @@
from types import SimpleNamespace
from unittest.mock import AsyncMock
from fastapi.testclient import TestClient
from app.api.v1.endpoints import user_management as user_management_endpoint
from app.auth.dependencies import get_current_active_user, get_user_repository
from app.auth.permissions import get_current_admin
from app.domain.models.role import UserRole
from tests.conftest import build_test_app, make_user
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
app = build_test_app(user_management_endpoint.router, "/users")
app.dependency_overrides[get_user_repository] = lambda: repo
if current_user is not None:
app.dependency_overrides[get_current_active_user] = lambda: current_user
if admin_user is not None:
app.dependency_overrides[get_current_admin] = lambda: admin_user
return TestClient(app)
def test_list_users_requires_admin_role():
repo = SimpleNamespace(
get_all_users=AsyncMock(
return_value=[
make_user(id=1, username="admin", role=UserRole.ADMIN),
make_user(id=2, username="user2"),
]
)
)
client = _build_client(
repo,
current_user=make_user(id=1, role=UserRole.ADMIN),
)
response = client.get("/users/", params={"skip": 5, "limit": 2})
assert response.status_code == 200
assert len(response.json()) == 2
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
def test_get_user_rejects_non_owner_non_admin():
repo = SimpleNamespace(get_user_by_id=AsyncMock())
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
response = client.get("/users/3")
assert response.status_code == 403
assert response.json()["detail"] == "You don't have permission to view this user"
repo.get_user_by_id.assert_not_awaited()
def test_update_user_blocks_role_change_for_non_admin():
repo = SimpleNamespace(
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
update_user=AsyncMock(),
)
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
response = client.put("/users/1", json={"role": "ADMIN"})
assert response.status_code == 403
assert response.json()["detail"] == "Only admins can change user roles"
repo.update_user.assert_not_awaited()
def test_delete_user_blocks_self_delete_for_admin():
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
repo = SimpleNamespace(delete_user=AsyncMock())
client = _build_client(repo, admin_user=admin_user)
response = client.delete("/users/1")
assert response.status_code == 400
assert response.json()["detail"] == "You cannot delete your own account"
repo.delete_user.assert_not_awaited()
def test_activate_user_updates_active_flag():
repo = SimpleNamespace(
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
)
client = _build_client(
repo,
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
)
response = client.post("/users/2/activate")
assert response.status_code == 200
assert response.json()["is_active"] is True
user_update = repo.update_user.await_args.args[1]
assert user_update.is_active is True
+36
View File
@@ -0,0 +1,36 @@
from jose import jwt
from app.core.config import settings
from app.core.security import (
create_access_token,
create_refresh_token,
get_password_hash,
verify_password,
)
def test_password_hash_roundtrip():
hashed = get_password_hash("secret123")
assert hashed != "secret123"
assert verify_password("secret123", hashed) is True
assert verify_password("wrong", hashed) is False
def test_create_access_token_sets_access_type():
token = create_access_token("alice")
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == "alice"
assert payload["type"] == "access"
assert "exp" in payload
assert "iat" in payload
def test_create_refresh_token_sets_refresh_type():
token = create_refresh_token("alice")
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
assert payload["sub"] == "alice"
assert payload["type"] == "refresh"
assert "exp" in payload
assert "iat" in payload
+189 -6
View File
@@ -1,14 +1,197 @@
import pytest
import sys
import importlib
import importlib.util
import os
import sys
import types
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
from fastapi import FastAPI
# 自动添加项目根目录到路径(处理项目结构)
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))
def run_this_test(test_file):
"""自定义函数:运行单个测试文件(类似pytest)"""
# 提取测试文件名(无扩展名)
test_name = os.path.splitext(os.path.basename(test_file))[0]
# 使用pytest运行(自动处理导入)
pytest.main([test_file, "-v"])
def build_test_app(router, prefix: str = "") -> FastAPI:
app = FastAPI()
app.include_router(router, prefix=prefix)
return app
def load_module_from_path(module_name: str, relative_path: str):
module_path = PROJECT_ROOT / relative_path
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
assert spec and spec.loader
spec.loader.exec_module(module)
return module
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
module = types.ModuleType(name)
if package:
module.__path__ = []
if attrs:
for key, value in attrs.items():
setattr(module, key, value)
monkeypatch.setitem(sys.modules, name, module)
parent_name, _, child_name = name.rpartition(".")
if parent_name:
parent = sys.modules.get(parent_name)
if parent is None:
try:
parent = importlib.import_module(parent_name)
except Exception:
parent = types.ModuleType(parent_name)
parent.__path__ = []
monkeypatch.setitem(sys.modules, parent_name, parent)
setattr(parent, child_name, module)
return module
class FakeCursor:
def __init__(
self,
*,
fetchone_results=None,
fetchall_results=None,
rowcount: int = 0,
rowcounts=None,
):
self._fetchone_results = list(fetchone_results or [])
self._fetchall_results = list(fetchall_results or [])
self._rowcounts = list(rowcounts or [])
self.rowcount = rowcount
self.executed: list[tuple[str, dict | tuple | None]] = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def execute(self, query, params=None):
self.executed.append((str(query), params))
if self._rowcounts:
self.rowcount = self._rowcounts.pop(0)
async def fetchone(self):
if self._fetchone_results:
return self._fetchone_results.pop(0)
return None
async def fetchall(self):
if self._fetchall_results:
return self._fetchall_results.pop(0)
return []
class FakeConnection:
def __init__(self, cursor: FakeCursor):
self.cursor_instance = cursor
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
def cursor(self):
return self.cursor_instance
class FakeDB:
def __init__(self, cursor: FakeCursor):
self.connection = FakeConnection(cursor)
def get_connection(self):
return self.connection
class FakeExecuteResult:
def __init__(self, *, rows=None, scalar_value=None):
self._rows = list(rows or [])
self._scalar_value = scalar_value
def scalars(self):
return self
def all(self):
return self._rows
def scalar(self):
return self._scalar_value
class FakeAsyncSession:
def __init__(self, execute_results=None):
self._execute_results = list(execute_results or [])
self.executed = []
self.added = []
self.commit_count = 0
self.refreshed = []
def add(self, obj):
self.added.append(obj)
async def execute(self, stmt):
self.executed.append(stmt)
if self._execute_results:
return self._execute_results.pop(0)
return FakeExecuteResult()
async def commit(self):
self.commit_count += 1
async def refresh(self, obj):
self.refreshed.append(obj)
def make_user(**overrides):
from app.domain.models.role import UserRole
from app.domain.schemas.user import UserInDB
data = {
"id": 1,
"username": "tester",
"email": "tester@example.com",
"hashed_password": "hashed-password",
"role": UserRole.USER,
"is_active": True,
"is_superuser": False,
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
"updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
data.update(overrides)
return UserInDB(**data)
def make_audit_log(**overrides):
data = {
"id": uuid4(),
"user_id": uuid4(),
"project_id": uuid4(),
"action": "LOGIN",
"resource_type": "user",
"resource_id": "1",
"ip_address": "127.0.0.1",
"request_method": "GET",
"request_path": "/audit/logs",
"request_data": {"ok": True},
"response_status": 200,
"timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc),
}
data.update(overrides)
return SimpleNamespace(**data)
+79
View File
@@ -0,0 +1,79 @@
import asyncio
from datetime import datetime, timezone
from uuid import uuid4
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
from tests.conftest import FakeAsyncSession, FakeExecuteResult, make_audit_log
def test_create_log_adds_commits_and_refreshes(monkeypatch):
class FakeAuditLog:
def __init__(self, **kwargs):
self.id = uuid4()
for key, value in kwargs.items():
setattr(self, key, value)
session = FakeAsyncSession()
repo = AuditRepository(session)
monkeypatch.setattr(
"app.infra.db.metadb.repositories.audit_repository.models.AuditLog",
FakeAuditLog,
)
result = asyncio.run(
repo.create_log(
action="LOGIN",
request_method="POST",
request_path="/auth/login",
response_status=200,
)
)
assert result.action == "LOGIN"
assert result.request_method == "POST"
assert session.commit_count == 1
assert len(session.added) == 1
assert len(session.refreshed) == 1
def test_get_logs_builds_filtered_query_and_returns_models():
log = make_audit_log(action="UPDATE_USER", resource_type="user")
session = FakeAsyncSession(
execute_results=[FakeExecuteResult(rows=[log])],
)
repo = AuditRepository(session)
user_id = uuid4()
project_id = uuid4()
start_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
results = asyncio.run(
repo.get_logs(
user_id=user_id,
project_id=project_id,
action="UPDATE_USER",
resource_type="user",
start_time=start_time,
skip=5,
limit=10,
)
)
assert len(results) == 1
assert results[0].action == "UPDATE_USER"
stmt = session.executed[0]
assert len(stmt._where_criteria) == 5
assert stmt._offset == 5
assert stmt._limit == 10
def test_get_log_count_returns_zero_when_scalar_none():
session = FakeAsyncSession(
execute_results=[FakeExecuteResult(scalar_value=None)],
)
repo = AuditRepository(session)
result = asyncio.run(repo.get_log_count(action="DELETE_USER"))
assert result == 0
stmt = session.executed[0]
assert len(stmt._where_criteria) == 1
+97
View File
@@ -0,0 +1,97 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from app.auth import dependencies
from app.core.security import create_access_token, create_refresh_token
from tests.conftest import make_user
def test_get_db_returns_app_state_db():
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(db="db-instance")))
result = asyncio.run(dependencies.get_db(request))
assert result == "db-instance"
def test_get_db_raises_when_database_missing():
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace()))
with pytest.raises(HTTPException) as exc_info:
asyncio.run(dependencies.get_db(request))
assert exc_info.value.status_code == 503
assert exc_info.value.detail == "Database not initialized"
def test_get_current_user_accepts_valid_access_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=make_user()))
result = asyncio.run(
dependencies.get_current_user(
token=create_access_token("tester"),
user_repo=repo,
)
)
assert result.username == "tester"
repo.get_user_by_username.assert_awaited_once_with("tester")
def test_get_current_user_rejects_refresh_token():
repo = SimpleNamespace(get_user_by_username=AsyncMock())
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_user(
token=create_refresh_token("tester"),
user_repo=repo,
)
)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid token type. Access token required."
repo.get_user_by_username.assert_not_awaited()
def test_get_current_user_rejects_missing_user():
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=None))
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_user(
token=create_access_token("ghost"),
user_repo=repo,
)
)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Could not validate credentials"
def test_get_current_active_user_rejects_inactive_user():
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_active_user(
current_user=make_user(is_active=False),
)
)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Inactive user"
def test_get_current_superuser_rejects_non_superuser():
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
dependencies.get_current_superuser(
current_user=make_user(is_superuser=False),
)
)
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "Not enough privileges. Superuser access required."
+56
View File
@@ -0,0 +1,56 @@
import asyncio
import pytest
from fastapi import HTTPException
from app.auth import permissions
from app.domain.models.role import UserRole
from tests.conftest import make_user
def test_require_role_allows_higher_privilege_user():
checker = permissions.require_role(UserRole.OPERATOR)
result = asyncio.run(checker(current_user=make_user(role=UserRole.ADMIN)))
assert result.role == UserRole.ADMIN
def test_require_role_rejects_insufficient_role():
checker = permissions.require_role(UserRole.ADMIN)
with pytest.raises(HTTPException) as exc_info:
asyncio.run(checker(current_user=make_user(role=UserRole.USER)))
assert exc_info.value.status_code == 403
assert "Required role: ADMIN" in exc_info.value.detail
def test_check_resource_owner_allows_admin():
assert permissions.check_resource_owner(
99,
make_user(id=1, role=UserRole.ADMIN),
) is True
def test_check_resource_owner_allows_owner():
assert permissions.check_resource_owner(
7,
make_user(id=7, role=UserRole.USER),
) is True
def test_check_resource_owner_rejects_other_user():
assert permissions.check_resource_owner(
7,
make_user(id=8, role=UserRole.USER),
) is False
def test_require_owner_or_admin_rejects_other_user():
checker = permissions.require_owner_or_admin(7)
with pytest.raises(HTTPException) as exc_info:
asyncio.run(checker(current_user=make_user(id=8, role=UserRole.USER)))
assert exc_info.value.status_code == 403
assert exc_info.value.detail == "You don't have permission to access this resource"
+19 -18
View File
@@ -1,9 +1,8 @@
import asyncio
from datetime import datetime, timezone
import importlib.util
from pathlib import Path
import pytest
def _load_scada_repository():
module_path = (
@@ -50,18 +49,19 @@ class _FakeConnection:
return self.cursor_instance
@pytest.mark.asyncio
async def test_update_scada_field_inserts_when_update_hits_no_rows():
def test_update_scada_field_inserts_when_update_hits_no_rows():
ScadaRepository = _load_scada_repository()
conn = _FakeConnection(initial_rowcount=0)
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
await ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
asyncio.run(
ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
)
)
assert len(conn.cursor_instance.calls) == 2
@@ -69,18 +69,19 @@ async def test_update_scada_field_inserts_when_update_hits_no_rows():
assert "INSERT INTO scada.scada_data" in conn.cursor_instance.calls[1][0]
@pytest.mark.asyncio
async def test_update_scada_field_skips_insert_when_update_succeeds():
def test_update_scada_field_skips_insert_when_update_succeeds():
ScadaRepository = _load_scada_repository()
conn = _FakeConnection(initial_rowcount=1)
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
await ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
asyncio.run(
ScadaRepository.update_scada_field(
conn,
point_time,
"170490",
"cleaned_value",
26.5,
)
)
assert len(conn.cursor_instance.calls) == 1
+124
View File
@@ -0,0 +1,124 @@
import asyncio
from unittest.mock import AsyncMock
import pytest
from app.domain.models.role import UserRole
from app.domain.schemas.user import UserCreate, UserUpdate
from app.infra.db.metadb.repositories.user_repository import UserRepository
from tests.conftest import FakeCursor, FakeDB
def _user_row(**overrides):
base = {
"id": 1,
"username": "tester",
"email": "tester@example.com",
"hashed_password": "hashed-password",
"role": "USER",
"is_active": True,
"is_superuser": False,
"created_at": "2025-01-01T00:00:00+00:00",
"updated_at": "2025-01-01T00:00:00+00:00",
}
base.update(overrides)
return base
def test_create_user_hashes_password_and_returns_model(monkeypatch):
cursor = FakeCursor(fetchone_results=[_user_row()])
repo = UserRepository(FakeDB(cursor))
monkeypatch.setattr(
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
lambda password: f"hashed::{password}",
)
result = asyncio.run(
repo.create_user(
UserCreate(
username="tester",
email="tester@example.com",
password="secret123",
)
)
)
assert result is not None
assert result.username == "tester"
assert cursor.executed[0][1]["hashed_password"] == "hashed::secret123"
def test_update_user_without_fields_returns_existing_user(monkeypatch):
repo = UserRepository(FakeDB(FakeCursor()))
existing_user = AsyncMock(return_value="existing")
monkeypatch.setattr(repo, "get_user_by_id", existing_user)
result = asyncio.run(repo.update_user(1, UserUpdate()))
assert result == "existing"
existing_user.assert_awaited_once_with(1)
def test_update_user_builds_dynamic_query(monkeypatch):
cursor = FakeCursor(fetchone_results=[_user_row(role="ADMIN", email="new@example.com")])
repo = UserRepository(FakeDB(cursor))
monkeypatch.setattr(
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
lambda password: f"hashed::{password}",
)
result = asyncio.run(
repo.update_user(
1,
UserUpdate(
email="new@example.com",
password="new-secret",
role=UserRole.ADMIN,
is_active=False,
),
),
)
assert result is not None
query, params = cursor.executed[0]
assert "email = %(email)s" in query
assert "hashed_password = %(hashed_password)s" in query
assert "role = %(role)s" in query
assert "is_active = %(is_active)s" in query
assert params["hashed_password"] == "hashed::new-secret"
assert params["role"] == "ADMIN"
assert params["is_active"] is False
def test_delete_user_returns_false_when_execute_raises():
cursor = FakeCursor()
cursor.execute = AsyncMock(side_effect=RuntimeError("boom"))
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(repo.delete_user(1))
assert result is False
def test_user_exists_short_circuits_without_filters():
cursor = FakeCursor()
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(repo.user_exists())
assert result is False
assert cursor.executed == []
def test_user_exists_checks_username_or_email():
cursor = FakeCursor(fetchone_results=[{"exists": True}])
repo = UserRepository(FakeDB(cursor))
result = asyncio.run(
repo.user_exists(username="tester", email="tester@example.com")
)
assert result is True
query, params = cursor.executed[0]
assert "username = %(username)s OR email = %(email)s" in query
assert params == {"username": "tester", "email": "tester@example.com"}
+76
View File
@@ -0,0 +1,76 @@
# `app/api/v1/endpoints/` 失效 API 排查与修正
排查范围:`app/api/v1/endpoints/`
结论:本次共确认 5 个问题接口,处理结果如下:
- **已删除 4 个未实现坏接口**
- **已修正 1 个签名失配接口**
> 路由统一前缀来自 `app/main.py:71`,以下完整路径均以 `/api/v1` 开头。
## 处理结果
| Method | API | 原问题 | 处理结果 |
| --- | --- | --- | --- |
| GET | `/api/v1/calculateregion/` | 调用时 `NameError`,底层无 `calculate_region` 实现 | **已删除** |
| GET | `/api/v1/getallregions/` | 调用时 `NameError`,底层无 `get_all_regions` 实现 | **已删除** |
| POST | `/api/v1/generateregion/` | 调用时 `NameError`,底层无 `generate_region` 实现 | **已删除** |
| GET | `/api/v1/calculatedistrictmeteringarea/` | 调用时 `NameError`,仍指向已废弃旧 DMA 入口 | **已删除** |
| GET | `/api/v1/calculateservicearea/` | endpoint 传 `time_index`,实现只接受 `name` | **已修正**,现返回全部时间步结果 |
## 删除原因
### 1. region 相关 3 个接口
以下能力在当前 `wndb` / `tjnetwork` 中均不存在:
- `calculate_region`
- `get_all_regions`
- `generate_region`
`app/native/wndb/__init__.py` 当前只提供 region CRUD 和 util 能力,不提供 region 计算或批量查询能力。因此这 3 个接口继续保留只会在运行时失败。
### 2. DMA 旧入口
旧接口 `calculate_district_metering_area(...)` 已不存在,当前只保留 3 个明确变体:
- `/api/v1/calculatedistrictmeteringareafornodes/`
- `/api/v1/calculatedistrictmeteringareaforregion/`
- `/api/v1/calculatedistrictmeteringareafornetwork/`
因此旧入口 `/api/v1/calculatedistrictmeteringarea/` 已删除,避免前端继续误用历史接口。
## 修正内容
### `GET /api/v1/calculateservicearea/`
原接口问题:
- endpoint 定义保留 `time_index`
- 实际实现 `calculate_service_area(name)` 只接收 `network/name`
- 调用时会触发参数数量不匹配
本次修正后:
- 移除 `time_index` 查询参数
- 返回类型改为 `list[dict[str, list[str]]]`
- 接口语义改为:**返回全部时间步的服务区计算结果**
## 当前可用替代接口
DMA 计算请使用:
- `/api/v1/calculatedistrictmeteringareafornodes/`
- `/api/v1/calculatedistrictmeteringareaforregion/`
- `/api/v1/calculatedistrictmeteringareafornetwork/`
服务区计算请使用:
- `/api/v1/calculateservicearea/`
现在返回全部时间步结果,不再接收 `time_index`
## 变更文件
- `app/api/v1/endpoints/network/regions.py`
- `失效API排查.md`