From 2317f4d527d3ec4720ec2ad18e020aaf764c3e33 Mon Sep 17 00:00:00 2001 From: Jiang Date: Thu, 21 May 2026 15:32:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=20API=20=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=A4=B1=E6=95=88?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/v1/endpoints/network/regions.py | 76 +------- tests/api/test_audit_endpoints.py | 91 +++++++++ tests/api/test_auth_endpoints.py | 139 ++++++++++++++ tests/api/test_project_endpoints.py | 152 +++++++++++++++ tests/api/test_regions_endpoints.py | 154 ++++++++++++++++ tests/api/test_simulation_endpoints.py | 175 ++++++++++++++++++ tests/api/test_user_management_endpoints.py | 95 ++++++++++ tests/auth/test_security.py | 36 ++++ tests/conftest.py | 195 +++++++++++++++++++- tests/unit/test_audit_repository.py | 79 ++++++++ tests/unit/test_auth_dependencies.py | 97 ++++++++++ tests/unit/test_permissions.py | 56 ++++++ tests/unit/test_scada_repository.py | 37 ++-- tests/unit/test_user_repository.py | 124 +++++++++++++ 失效API排查.md | 76 ++++++++ 15 files changed, 1486 insertions(+), 96 deletions(-) create mode 100644 tests/api/test_audit_endpoints.py create mode 100644 tests/api/test_auth_endpoints.py create mode 100644 tests/api/test_project_endpoints.py create mode 100644 tests/api/test_regions_endpoints.py create mode 100644 tests/api/test_simulation_endpoints.py create mode 100644 tests/api/test_user_management_endpoints.py create mode 100644 tests/auth/test_security.py create mode 100644 tests/unit/test_audit_repository.py create mode 100644 tests/unit/test_auth_dependencies.py create mode 100644 tests/unit/test_permissions.py create mode 100644 tests/unit/test_user_repository.py create mode 100644 失效API排查.md diff --git a/app/api/v1/endpoints/network/regions.py b/app/api/v1/endpoints/network/regions.py index 61c84f0..4d1b564 100644 --- a/app/api/v1/endpoints/network/regions.py +++ b/app/api/v1/endpoints/network/regions.py @@ -7,11 +7,9 @@ from app.services.tjnetwork import ( add_region, add_service_area, add_virtual_district, - # calculate_district_metering_area, calculate_district_metering_area_for_network, calculate_district_metering_area_for_nodes, calculate_district_metering_area_for_region, - # calculate_region, calculate_service_area, calculate_virtual_district, delete_district_metering_area, @@ -19,13 +17,11 @@ from app.services.tjnetwork import ( delete_service_area, delete_virtual_district, generate_district_metering_area, - # generate_region, generate_service_area, generate_sub_district_metering_area, generate_virtual_district, get_all_district_metering_area_ids, get_all_district_metering_areas, - # get_all_regions, get_all_service_areas, get_all_virtual_districts, get_district_metering_area, @@ -48,18 +44,6 @@ router = APIRouter() # region 32 ############################################################ -@router.get( - "/calculateregion/", - summary="计算区域", - description="计算指定水网在指定时间步长的区域分区" -) -async def fastapi_calculate_region( - network: str = Query(..., description="管网名称(或数据库名称)"), - time_index: int = Query(..., description="时间步长索引", ge=0) -) -> dict[str, Any]: - """计算区域分区。""" - return calculate_region(network, time_index) - @router.get( "/getregionschema/", summary="获取区域属性架构", @@ -125,62 +109,11 @@ async def fastapi_delete_region( props = await req.json() return delete_region(network, ChangeSet(props)) -@router.get( - "/getallregions/", - summary="获取所有区域", - description="获取指定水网中的所有区域信息" -) -async def fastapi_get_all_regions( - network: str = Query(..., description="管网名称(或数据库名称)") -) -> list[dict[str, Any]]: - """获取所有区域的信息列表。""" - return get_all_regions(network) - -@router.post( - "/generateregion/", - response_model=None, - summary="生成区域分区", - description="根据参数自动生成水网的区域分区" -) -async def fastapi_generate_region( - network: str = Query(..., description="管网名称(或数据库名称)"), - inflate_delta: float = Query(..., description="膨胀参数") -) -> ChangeSet: - """生成区域分区。""" - return generate_region(network, inflate_delta) - ############################################################ # district_metering_area 33 ############################################################ -@router.get( - "/calculatedistrictmeteringarea/", - summary="计算DMA分区", - description="计算指定节点集的区域计量(DMA)分区方案" -) -async def fastapi_calculate_district_metering_area( - network: str = Query(..., description="管网名称(或数据库名称)"), - req: Request = None -) -> list[list[str]]: - """ - 计算DMA分区。 - - 请求体格式: - { - "nodes": 节点ID列表(list[str]), - "part_count": 分区数量(int), - "part_type": 分区类型(int) - } - """ - props = await req.json() - nodes = props["nodes"] - part_count = props["part_count"] - part_type = props["part_type"] - return calculate_district_metering_area( - network, nodes, part_count, part_type - ) - @router.get( "/calculatedistrictmeteringareaforregion/", summary="计算区域内DMA分区", @@ -368,14 +301,13 @@ async def fastapi_generate_sub_district_metering_area( @router.get( "/calculateservicearea/", summary="计算服务区", - description="计算指定水网在指定时间步长的服务区分区" + description="计算指定水网的服务区分区,返回全部时间步结果" ) async def fastapi_calculate_service_area( network: str = Query(..., description="管网名称(或数据库名称)"), - time_index: int = Query(..., description="时间步长索引", ge=0) -) -> dict[str, Any]: - """计算服务区分区。""" - return calculate_service_area(network, time_index) +) -> list[dict[str, list[str]]]: + """计算服务区分区,返回全部时间步结果。""" + return calculate_service_area(network) @router.get( "/getserviceareaschema/", diff --git a/tests/api/test_audit_endpoints.py b/tests/api/test_audit_endpoints.py new file mode 100644 index 0000000..d043800 --- /dev/null +++ b/tests/api/test_audit_endpoints.py @@ -0,0 +1,91 @@ +from unittest.mock import AsyncMock + +from fastapi.testclient import TestClient + +from app.api.v1.endpoints import audit as audit_endpoint +from app.auth.metadata_dependencies import ( + get_current_metadata_admin, + get_current_metadata_user, +) +from tests.conftest import build_test_app, make_audit_log + + +def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient: + app = build_test_app(audit_endpoint.router, "/audit") + app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo + if metadata_admin is not None: + app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin + if metadata_user is not None: + app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user + return TestClient(app) + + +def test_get_audit_logs_passes_filters(): + repo = type( + "Repo", + (), + { + "get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]), + "get_log_count": AsyncMock(), + }, + )() + client = _build_client(repo, metadata_admin=object()) + + response = client.get( + "/audit/logs", + params={ + "action": "LOGIN", + "resource_type": "user", + "skip": 2, + "limit": 5, + }, + ) + + assert response.status_code == 200 + assert response.json()[0]["action"] == "LOGIN" + repo.get_logs.assert_awaited_once() + kwargs = repo.get_logs.await_args.kwargs + assert kwargs["action"] == "LOGIN" + assert kwargs["resource_type"] == "user" + assert kwargs["skip"] == 2 + assert kwargs["limit"] == 5 + + +def test_get_audit_logs_count_returns_count_payload(): + repo = type( + "Repo", + (), + { + "get_logs": AsyncMock(), + "get_log_count": AsyncMock(return_value=7), + }, + )() + client = _build_client(repo, metadata_admin=object()) + + response = client.get("/audit/logs/count", params={"action": "DELETE_USER"}) + + assert response.status_code == 200 + assert response.json() == {"count": 7} + repo.get_log_count.assert_awaited_once() + assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER" + + +def test_get_my_audit_logs_forces_current_user_id(): + current_user = type("User", (), {"id": make_audit_log().user_id})() + repo = type( + "Repo", + (), + { + "get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]), + "get_log_count": AsyncMock(), + }, + )() + client = _build_client(repo, metadata_user=current_user) + + response = client.get("/audit/logs/my", params={"limit": 3}) + + assert response.status_code == 200 + repo.get_logs.assert_awaited_once() + kwargs = repo.get_logs.await_args.kwargs + assert kwargs["user_id"] == current_user.id + assert kwargs["limit"] == 3 diff --git a/tests/api/test_auth_endpoints.py b/tests/api/test_auth_endpoints.py new file mode 100644 index 0000000..144e442 --- /dev/null +++ b/tests/api/test_auth_endpoints.py @@ -0,0 +1,139 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from fastapi.testclient import TestClient + +from app.api.v1.endpoints import auth as auth_endpoint +from app.auth.dependencies import get_current_active_user, get_user_repository +from app.core.security import create_access_token, create_refresh_token, get_password_hash +from tests.conftest import build_test_app, make_user + + +def _build_client(repo, current_user=None) -> TestClient: + app = build_test_app(auth_endpoint.router, "/api/v1/auth") + app.dependency_overrides[get_user_repository] = lambda: repo + if current_user is not None: + app.dependency_overrides[get_current_active_user] = lambda: current_user + return TestClient(app) + + +def test_register_success(): + repo = SimpleNamespace( + user_exists=AsyncMock(side_effect=[False, False]), + create_user=AsyncMock(return_value=make_user()), + ) + client = _build_client(repo) + + response = client.post( + "/api/v1/auth/register", + json={ + "username": "tester", + "email": "tester@example.com", + "password": "secret123", + }, + ) + + assert response.status_code == 201 + assert response.json()["username"] == "tester" + + +def test_register_rejects_duplicate_username(): + repo = SimpleNamespace( + user_exists=AsyncMock(side_effect=[True]), + create_user=AsyncMock(), + ) + client = _build_client(repo) + + response = client.post( + "/api/v1/auth/register", + json={ + "username": "tester", + "email": "tester@example.com", + "password": "secret123", + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Username already registered" + repo.create_user.assert_not_awaited() + + +def test_login_supports_email_lookup(): + hashed_password = get_password_hash("secret123") + repo = SimpleNamespace( + get_user_by_username=AsyncMock(return_value=None), + get_user_by_email=AsyncMock( + return_value=make_user( + email="tester@example.com", + hashed_password=hashed_password, + ) + ), + ) + client = _build_client(repo) + + response = client.post( + "/api/v1/auth/login", + data={"username": "tester@example.com", "password": "secret123"}, + ) + + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" + repo.get_user_by_email.assert_awaited_once_with("tester@example.com") + + +def test_login_simple_uses_query_params(): + hashed_password = get_password_hash("secret123") + repo = SimpleNamespace( + get_user_by_username=AsyncMock( + return_value=make_user(hashed_password=hashed_password) + ), + get_user_by_email=AsyncMock(), + ) + client = _build_client(repo) + + response = client.post( + "/api/v1/auth/login/simple", + params={"username": "tester", "password": "secret123"}, + ) + + assert response.status_code == 200 + assert response.json()["token_type"] == "bearer" + + +def test_me_returns_current_user_info(): + client = _build_client(SimpleNamespace(), current_user=make_user(username="alice")) + + response = client.get("/api/v1/auth/me") + + assert response.status_code == 200 + assert response.json()["username"] == "alice" + + +def test_refresh_rejects_access_token(): + repo = SimpleNamespace(get_user_by_username=AsyncMock()) + client = _build_client(repo) + + response = client.post( + "/api/v1/auth/refresh", + params={"refresh_token": create_access_token("tester")}, + ) + + assert response.status_code == 401 + + +def test_refresh_success_returns_new_access_token(): + repo = SimpleNamespace( + get_user_by_username=AsyncMock(return_value=make_user()), + ) + client = _build_client(repo) + refresh_token = create_refresh_token("tester") + + response = client.post( + "/api/v1/auth/refresh", + params={"refresh_token": refresh_token}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["refresh_token"] == refresh_token + assert payload["token_type"] == "bearer" diff --git a/tests/api/test_project_endpoints.py b/tests/api/test_project_endpoints.py new file mode 100644 index 0000000..374f4ad --- /dev/null +++ b/tests/api/test_project_endpoints.py @@ -0,0 +1,152 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock +from uuid import uuid4 + +from fastapi.testclient import TestClient + +from tests.conftest import build_test_app, install_stub, load_module_from_path + + +class DummyChangeSet: + def __init__(self, operations=None): + if operations is None: + self.operations = [] + elif isinstance(operations, dict): + self.operations = [operations] + else: + self.operations = operations + + +def _load_project_module(monkeypatch): + install_stub(monkeypatch, "app.services", package=True) + install_stub( + monkeypatch, + "app.services.project_info", + {}, + ) + install_stub( + monkeypatch, + "app.services.tjnetwork", + { + "ChangeSet": DummyChangeSet, + "list_project": lambda: ["demo"], + "have_project": lambda network: network == "demo", + "create_project": lambda network: None, + "delete_project": lambda network: None, + "is_project_open": lambda network: False, + "open_project": lambda network: None, + "close_project": lambda network: None, + "copy_project": lambda source, target: None, + "import_inp": lambda network, cs: {"ok": True}, + "export_inp": lambda network, version: DummyChangeSet({"kind": "export"}), + "read_inp": lambda network, inp: True, + "dump_inp": lambda network, inp: True, + "get_all_vertices": lambda network: [], + "get_all_scada_elements": lambda network: [], + "get_all_district_metering_areas": lambda network: [], + "get_all_service_areas": lambda network: [], + "get_all_virtual_districts": lambda network: [], + "get_extension_data": lambda network, key: None, + "convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}), + }, + ) + install_stub( + monkeypatch, + "app.auth.project_dependencies", + {"get_metadata_repository": lambda: None}, + ) + install_stub( + monkeypatch, + "app.infra.db.postgresql.database", + {"get_database_instance": lambda network: None}, + ) + install_stub( + monkeypatch, + "app.infra.db.timescaledb.database", + {"get_database_instance": lambda network: None}, + ) + return load_module_from_path( + "tests_project_endpoints_module", + "app/api/v1/endpoints/project.py", + ) + + +def test_project_info_returns_404_when_missing(monkeypatch): + module = _load_project_module(monkeypatch) + repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None)) + app = build_test_app(module.router, "/api/v1") + app.dependency_overrides[module.get_metadata_repository] = lambda: repo + client = TestClient(app) + + response = client.get("/api/v1/project_info/", params={"network": "missing"}) + + assert response.status_code == 404 + assert response.json()["detail"] == "Project missing not found" + + +def test_project_info_returns_geoserver_payload(monkeypatch): + module = _load_project_module(monkeypatch) + detail = SimpleNamespace( + project_id=uuid4(), + name="Demo Project", + code="demo", + description="desc", + gs_workspace="ws", + map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4}, + status="active", + geoserver=SimpleNamespace( + gs_base_url="http://gs", + gs_admin_user="admin", + gs_datastore_name="store", + default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4}, + srid=4326, + ), + ) + repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail)) + app = build_test_app(module.router, "/api/v1") + app.dependency_overrides[module.get_metadata_repository] = lambda: repo + client = TestClient(app) + + response = client.get("/api/v1/project_info/", params={"network": "demo"}) + + assert response.status_code == 200 + payload = response.json() + assert payload["code"] == "demo" + assert payload["geoserver"]["gs_base_url"] == "http://gs" + + +def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch): + module = _load_project_module(monkeypatch) + called = [] + + monkeypatch.setattr(module, "open_project", lambda network: called.append(network)) + + async def failing_get_pg_db(network): + raise RuntimeError("db down") + + monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post("/api/v1/openproject/", params={"network": "demo"}) + + assert response.status_code == 200 + assert response.json() == "demo" + assert called == ["demo"] + + +def test_project_lock_lifecycle(monkeypatch): + module = _load_project_module(monkeypatch) + module.lockedPrjs.clear() + client = TestClient(build_test_app(module.router, "/api/v1")) + + first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"}) + second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"}) + locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"}) + unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"}) + locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"}) + + assert first_lock.json() == 0 + assert second_lock.json() == 1 + assert locked_by_me.json() is True + assert unlock.json() is True + assert locked.json() is False diff --git a/tests/api/test_regions_endpoints.py b/tests/api/test_regions_endpoints.py new file mode 100644 index 0000000..b51a9cc --- /dev/null +++ b/tests/api/test_regions_endpoints.py @@ -0,0 +1,154 @@ +from typing import Any + +from fastapi.testclient import TestClient + +from tests.conftest import build_test_app, install_stub, load_module_from_path + + +class DummyChangeSet: + def __init__(self, operations=None): + if operations is None: + self.operations = [] + elif isinstance(operations, dict): + self.operations = [operations] + else: + self.operations = operations + + +def _noop(*args, **kwargs): + return None + + +def _load_regions_module(monkeypatch): + install_stub(monkeypatch, "app.services", package=True) + install_stub( + monkeypatch, + "app.services.tjnetwork", + { + "Any": Any, + "ChangeSet": DummyChangeSet, + "add_district_metering_area": _noop, + "add_region": _noop, + "add_service_area": _noop, + "add_virtual_district": _noop, + "calculate_district_metering_area_for_network": lambda *args, **kwargs: [], + "calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [], + "calculate_district_metering_area_for_region": lambda *args, **kwargs: [], + "calculate_service_area": lambda network: [], + "calculate_virtual_district": lambda *args, **kwargs: {}, + "delete_district_metering_area": _noop, + "delete_region": _noop, + "delete_service_area": _noop, + "delete_virtual_district": _noop, + "generate_district_metering_area": _noop, + "generate_service_area": _noop, + "generate_sub_district_metering_area": _noop, + "generate_virtual_district": _noop, + "get_all_district_metering_area_ids": lambda network: [], + "get_all_district_metering_areas": lambda network: [], + "get_all_service_areas": lambda network: [], + "get_all_virtual_districts": lambda network: [], + "get_district_metering_area": lambda network, area_id: {}, + "get_district_metering_area_schema": lambda network: {}, + "get_region": lambda network, region_id: {}, + "get_region_schema": lambda network: {}, + "get_service_area": lambda network, area_id: {}, + "get_service_area_schema": lambda network: {}, + "get_virtual_district": lambda network, area_id: {}, + "get_virtual_district_schema": lambda network: {}, + "set_district_metering_area": _noop, + "set_region": _noop, + "set_service_area": _noop, + "set_virtual_district": _noop, + }, + ) + return load_module_from_path( + "tests_regions_endpoints_module", + "app/api/v1/endpoints/network/regions.py", + ) + + +def test_removed_routes_are_absent_and_return_404(monkeypatch): + module = _load_regions_module(monkeypatch) + client = TestClient(build_test_app(module.router, "/api/v1")) + + openapi = client.get("/openapi.json").json() + + assert "/api/v1/calculateregion/" not in openapi["paths"] + assert "/api/v1/getallregions/" not in openapi["paths"] + assert "/api/v1/generateregion/" not in openapi["paths"] + assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"] + assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404 + assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404 + + +def test_calculate_service_area_contract_uses_only_network(monkeypatch): + module = _load_regions_module(monkeypatch) + calls = [] + monkeypatch.setattr( + module, + "calculate_service_area", + lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}], + ) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.get( + "/api/v1/calculateservicearea/", + params={"network": "demo", "time_index": 5}, + ) + schema = client.get("/openapi.json").json() + + assert response.status_code == 200 + assert response.json() == [{"source-1": ["n1", "n2"]}] + assert calls == ["demo"] + parameter_names = [ + item["name"] + for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"] + ] + assert parameter_names == ["network"] + + +def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch): + module = _load_regions_module(monkeypatch) + captured = {} + + def fake_add(network, change_set): + captured["network"] = network + captured["boundary"] = change_set.operations[0]["boundary"] + return {"ok": True} + + monkeypatch.setattr(module, "add_district_metering_area", fake_add) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post( + "/api/v1/adddistrictmeteringarea/", + params={"network": "demo"}, + json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]}, + ) + + assert response.status_code == 200 + assert captured == { + "network": "demo", + "boundary": [(1, 2), (3, 4), (1, 2)], + } + + +def test_generate_virtual_district_reads_centers_from_body(monkeypatch): + module = _load_regions_module(monkeypatch) + captured = {} + + def fake_generate(network, centers, inflate_delta): + captured["args"] = (network, centers, inflate_delta) + return {"generated": True} + + monkeypatch.setattr(module, "generate_virtual_district", fake_generate) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post( + "/api/v1/generatevirtualdistrict/", + params={"network": "demo", "inflate_delta": 0.75}, + json={"centers": ["J1", "J2"]}, + ) + + assert response.status_code == 200 + assert captured["args"] == ("demo", ["J1", "J2"], 0.75) diff --git a/tests/api/test_simulation_endpoints.py b/tests/api/test_simulation_endpoints.py new file mode 100644 index 0000000..e9cdb1c --- /dev/null +++ b/tests/api/test_simulation_endpoints.py @@ -0,0 +1,175 @@ +from pathlib import Path + +from fastapi.testclient import TestClient + +from tests.conftest import build_test_app, install_stub, load_module_from_path + + +def _load_simulation_module(monkeypatch): + install_stub(monkeypatch, "app.services", package=True) + install_stub( + monkeypatch, + "app.services.simulation", + {"run_simulation": lambda **kwargs: None}, + ) + install_stub(monkeypatch, "app.services.globals", {}) + install_stub( + monkeypatch, + "app.services.tjnetwork", + { + "run_project": lambda network: "report", + "run_project_return_dict": lambda network: {"output": {}, "report": "ok"}, + "run_inp": lambda network: "inp-report", + "dump_output": lambda output: f"dump::{output}", + }, + ) + install_stub(monkeypatch, "app.algorithms", package=True) + install_stub(monkeypatch, "app.algorithms.simulation", package=True) + install_stub( + monkeypatch, + "app.algorithms.simulation.scenarios", + { + "burst_analysis": lambda *args, **kwargs: "burst", + "valve_close_analysis": lambda *args, **kwargs: "valve", + "flushing_analysis": lambda *args, **kwargs: "flush", + "contaminant_simulation": lambda *args, **kwargs: "contaminant", + "age_analysis": lambda *args, **kwargs: "age", + "pressure_regulation": lambda *args, **kwargs: "pressure", + }, + ) + install_stub( + monkeypatch, + "app.algorithms.sensor", + { + "pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [], + "pressure_sensor_placement_kmeans": lambda *args, **kwargs: [], + }, + ) + install_stub( + monkeypatch, + "app.services.network_import", + {"network_update": lambda *args, **kwargs: "updated"}, + ) + install_stub( + monkeypatch, + "app.services.simulation_ops", + { + "project_management": lambda *args, **kwargs: "managed", + "scheduling_simulation": lambda *args, **kwargs: "scheduled", + "daily_scheduling_simulation": lambda *args, **kwargs: "daily", + }, + ) + install_stub( + monkeypatch, + "app.services.valve_isolation", + {"analyze_valve_isolation": lambda *args, **kwargs: {}}, + ) + return load_module_from_path( + "tests_simulation_endpoints_module", + "app/api/v1/endpoints/simulation.py", + ) + + +def test_run_project_endpoint_returns_plain_text(monkeypatch): + module = _load_simulation_module(monkeypatch) + monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}") + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.get("/api/v1/runproject/", params={"network": "demo"}) + + assert response.status_code == 200 + assert response.text == "report::demo" + + +def test_scheduling_analysis_maps_request_body(monkeypatch): + module = _load_simulation_module(monkeypatch) + captured = {} + + def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta): + captured["args"] = ( + network, + start_time, + pump_control, + tank_id, + water_plant_output_id, + time_delta, + ) + return "scheduled" + + monkeypatch.setattr(module, "scheduling_simulation", fake_schedule) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post( + "/api/v1/scheduling_analysis/", + json={ + "network": "demo", + "start_time": "2025-01-01T08:00:00+08:00", + "pump_control": {"P1": [1, 0, 1]}, + "tank_id": "T1", + "water_plant_output_id": "R1", + }, + ) + + assert response.status_code == 200 + assert response.json() == "scheduled" + assert captured["args"] == ( + "demo", + "2025-01-01T08:00:00+08:00", + {"P1": [1, 0, 1]}, + "T1", + "R1", + 300, + ) + + +def test_project_management_maps_named_arguments(monkeypatch): + module = _load_simulation_module(monkeypatch) + captured = {} + + def fake_project_management(**kwargs): + captured.update(kwargs) + return "managed" + + monkeypatch.setattr(module, "project_management", fake_project_management) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post( + "/api/v1/project_management/", + json={ + "network": "demo", + "start_time": "2025-01-01T08:00:00+08:00", + "pump_control": {"P1": [1]}, + "tank_init_level": {"T1": 10.0}, + "region_demand": {"R1": 20.0}, + }, + ) + + assert response.status_code == 200 + assert response.json() == "managed" + assert captured == { + "prj_name": "demo", + "start_datetime": "2025-01-01T08:00:00+08:00", + "pump_control": {"P1": [1]}, + "tank_initial_level_control": {"T1": 10.0}, + "region_demand_control": {"R1": 20.0}, + } + + +def test_network_update_surfaces_service_error(monkeypatch, tmp_path): + module = _load_simulation_module(monkeypatch) + monkeypatch.chdir(tmp_path) + + def boom(_path): + raise RuntimeError("write failed") + + monkeypatch.setattr(module, "network_update", boom) + client = TestClient(build_test_app(module.router, "/api/v1")) + + response = client.post( + "/api/v1/network_update/", + files={"file": ("update.txt", b"payload")}, + ) + + assert response.status_code == 500 + assert "数据库操作失败: write failed" in response.json()["detail"] + assert list(Path(tmp_path).glob("network_update_*")) diff --git a/tests/api/test_user_management_endpoints.py b/tests/api/test_user_management_endpoints.py new file mode 100644 index 0000000..8991d5a --- /dev/null +++ b/tests/api/test_user_management_endpoints.py @@ -0,0 +1,95 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from fastapi.testclient import TestClient + +from app.api.v1.endpoints import user_management as user_management_endpoint +from app.auth.dependencies import get_current_active_user, get_user_repository +from app.auth.permissions import get_current_admin +from app.domain.models.role import UserRole +from tests.conftest import build_test_app, make_user + + +def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient: + app = build_test_app(user_management_endpoint.router, "/users") + app.dependency_overrides[get_user_repository] = lambda: repo + if current_user is not None: + app.dependency_overrides[get_current_active_user] = lambda: current_user + if admin_user is not None: + app.dependency_overrides[get_current_admin] = lambda: admin_user + return TestClient(app) + + +def test_list_users_requires_admin_role(): + repo = SimpleNamespace( + get_all_users=AsyncMock( + return_value=[ + make_user(id=1, username="admin", role=UserRole.ADMIN), + make_user(id=2, username="user2"), + ] + ) + ) + client = _build_client( + repo, + current_user=make_user(id=1, role=UserRole.ADMIN), + ) + + response = client.get("/users/", params={"skip": 5, "limit": 2}) + + assert response.status_code == 200 + assert len(response.json()) == 2 + repo.get_all_users.assert_awaited_once_with(skip=5, limit=2) + + +def test_get_user_rejects_non_owner_non_admin(): + repo = SimpleNamespace(get_user_by_id=AsyncMock()) + client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER)) + + response = client.get("/users/3") + + assert response.status_code == 403 + assert response.json()["detail"] == "You don't have permission to view this user" + repo.get_user_by_id.assert_not_awaited() + + +def test_update_user_blocks_role_change_for_non_admin(): + repo = SimpleNamespace( + get_user_by_id=AsyncMock(return_value=make_user(id=1)), + update_user=AsyncMock(), + ) + client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER)) + + response = client.put("/users/1", json={"role": "ADMIN"}) + + assert response.status_code == 403 + assert response.json()["detail"] == "Only admins can change user roles" + repo.update_user.assert_not_awaited() + + +def test_delete_user_blocks_self_delete_for_admin(): + admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True) + repo = SimpleNamespace(delete_user=AsyncMock()) + client = _build_client(repo, admin_user=admin_user) + + response = client.delete("/users/1") + + assert response.status_code == 400 + assert response.json()["detail"] == "You cannot delete your own account" + repo.delete_user.assert_not_awaited() + + +def test_activate_user_updates_active_flag(): + repo = SimpleNamespace( + update_user=AsyncMock(return_value=make_user(id=2, is_active=True)), + ) + client = _build_client( + repo, + admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True), + ) + + response = client.post("/users/2/activate") + + assert response.status_code == 200 + assert response.json()["is_active"] is True + user_update = repo.update_user.await_args.args[1] + assert user_update.is_active is True diff --git a/tests/auth/test_security.py b/tests/auth/test_security.py new file mode 100644 index 0000000..8953108 --- /dev/null +++ b/tests/auth/test_security.py @@ -0,0 +1,36 @@ +from jose import jwt + +from app.core.config import settings +from app.core.security import ( + create_access_token, + create_refresh_token, + get_password_hash, + verify_password, +) + + +def test_password_hash_roundtrip(): + hashed = get_password_hash("secret123") + assert hashed != "secret123" + assert verify_password("secret123", hashed) is True + assert verify_password("wrong", hashed) is False + + +def test_create_access_token_sets_access_type(): + token = create_access_token("alice") + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + + assert payload["sub"] == "alice" + assert payload["type"] == "access" + assert "exp" in payload + assert "iat" in payload + + +def test_create_refresh_token_sets_refresh_type(): + token = create_refresh_token("alice") + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + + assert payload["sub"] == "alice" + assert payload["type"] == "refresh" + assert "exp" in payload + assert "iat" in payload diff --git a/tests/conftest.py b/tests/conftest.py index b1c7e7b..d528272 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,197 @@ -import pytest -import sys +import importlib +import importlib.util import os +import sys +import types +from datetime import datetime, timezone +from pathlib import Path +from types import SimpleNamespace +from uuid import uuid4 + +import pytest +from fastapi import FastAPI # 自动添加项目根目录到路径(处理项目结构) -sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) +PROJECT_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(PROJECT_ROOT)) def run_this_test(test_file): """自定义函数:运行单个测试文件(类似pytest)""" - # 提取测试文件名(无扩展名) - test_name = os.path.splitext(os.path.basename(test_file))[0] - # 使用pytest运行(自动处理导入) pytest.main([test_file, "-v"]) + + +def build_test_app(router, prefix: str = "") -> FastAPI: + app = FastAPI() + app.include_router(router, prefix=prefix) + return app + + +def load_module_from_path(module_name: str, relative_path: str): + module_path = PROJECT_ROOT / relative_path + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) + return module + + +def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False): + module = types.ModuleType(name) + if package: + module.__path__ = [] + if attrs: + for key, value in attrs.items(): + setattr(module, key, value) + + monkeypatch.setitem(sys.modules, name, module) + + parent_name, _, child_name = name.rpartition(".") + if parent_name: + parent = sys.modules.get(parent_name) + if parent is None: + try: + parent = importlib.import_module(parent_name) + except Exception: + parent = types.ModuleType(parent_name) + parent.__path__ = [] + monkeypatch.setitem(sys.modules, parent_name, parent) + setattr(parent, child_name, module) + + return module + + +class FakeCursor: + def __init__( + self, + *, + fetchone_results=None, + fetchall_results=None, + rowcount: int = 0, + rowcounts=None, + ): + self._fetchone_results = list(fetchone_results or []) + self._fetchall_results = list(fetchall_results or []) + self._rowcounts = list(rowcounts or []) + self.rowcount = rowcount + self.executed: list[tuple[str, dict | tuple | None]] = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def execute(self, query, params=None): + self.executed.append((str(query), params)) + if self._rowcounts: + self.rowcount = self._rowcounts.pop(0) + + async def fetchone(self): + if self._fetchone_results: + return self._fetchone_results.pop(0) + return None + + async def fetchall(self): + if self._fetchall_results: + return self._fetchall_results.pop(0) + return [] + + +class FakeConnection: + def __init__(self, cursor: FakeCursor): + self.cursor_instance = cursor + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + def cursor(self): + return self.cursor_instance + + +class FakeDB: + def __init__(self, cursor: FakeCursor): + self.connection = FakeConnection(cursor) + + def get_connection(self): + return self.connection + + +class FakeExecuteResult: + def __init__(self, *, rows=None, scalar_value=None): + self._rows = list(rows or []) + self._scalar_value = scalar_value + + def scalars(self): + return self + + def all(self): + return self._rows + + def scalar(self): + return self._scalar_value + + +class FakeAsyncSession: + def __init__(self, execute_results=None): + self._execute_results = list(execute_results or []) + self.executed = [] + self.added = [] + self.commit_count = 0 + self.refreshed = [] + + def add(self, obj): + self.added.append(obj) + + async def execute(self, stmt): + self.executed.append(stmt) + if self._execute_results: + return self._execute_results.pop(0) + return FakeExecuteResult() + + async def commit(self): + self.commit_count += 1 + + async def refresh(self, obj): + self.refreshed.append(obj) + + +def make_user(**overrides): + from app.domain.models.role import UserRole + from app.domain.schemas.user import UserInDB + + data = { + "id": 1, + "username": "tester", + "email": "tester@example.com", + "hashed_password": "hashed-password", + "role": UserRole.USER, + "is_active": True, + "is_superuser": False, + "created_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + "updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + data.update(overrides) + return UserInDB(**data) + + +def make_audit_log(**overrides): + data = { + "id": uuid4(), + "user_id": uuid4(), + "project_id": uuid4(), + "action": "LOGIN", + "resource_type": "user", + "resource_id": "1", + "ip_address": "127.0.0.1", + "request_method": "GET", + "request_path": "/audit/logs", + "request_data": {"ok": True}, + "response_status": 200, + "timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc), + } + data.update(overrides) + return SimpleNamespace(**data) diff --git a/tests/unit/test_audit_repository.py b/tests/unit/test_audit_repository.py new file mode 100644 index 0000000..01a16bb --- /dev/null +++ b/tests/unit/test_audit_repository.py @@ -0,0 +1,79 @@ +import asyncio +from datetime import datetime, timezone +from uuid import uuid4 + +from app.infra.db.metadb.repositories.audit_repository import AuditRepository +from tests.conftest import FakeAsyncSession, FakeExecuteResult, make_audit_log + + +def test_create_log_adds_commits_and_refreshes(monkeypatch): + class FakeAuditLog: + def __init__(self, **kwargs): + self.id = uuid4() + for key, value in kwargs.items(): + setattr(self, key, value) + + session = FakeAsyncSession() + repo = AuditRepository(session) + monkeypatch.setattr( + "app.infra.db.metadb.repositories.audit_repository.models.AuditLog", + FakeAuditLog, + ) + + result = asyncio.run( + repo.create_log( + action="LOGIN", + request_method="POST", + request_path="/auth/login", + response_status=200, + ) + ) + + assert result.action == "LOGIN" + assert result.request_method == "POST" + assert session.commit_count == 1 + assert len(session.added) == 1 + assert len(session.refreshed) == 1 + + +def test_get_logs_builds_filtered_query_and_returns_models(): + log = make_audit_log(action="UPDATE_USER", resource_type="user") + session = FakeAsyncSession( + execute_results=[FakeExecuteResult(rows=[log])], + ) + repo = AuditRepository(session) + user_id = uuid4() + project_id = uuid4() + start_time = datetime(2025, 1, 1, tzinfo=timezone.utc) + + results = asyncio.run( + repo.get_logs( + user_id=user_id, + project_id=project_id, + action="UPDATE_USER", + resource_type="user", + start_time=start_time, + skip=5, + limit=10, + ) + ) + + assert len(results) == 1 + assert results[0].action == "UPDATE_USER" + stmt = session.executed[0] + assert len(stmt._where_criteria) == 5 + assert stmt._offset == 5 + assert stmt._limit == 10 + + +def test_get_log_count_returns_zero_when_scalar_none(): + session = FakeAsyncSession( + execute_results=[FakeExecuteResult(scalar_value=None)], + ) + repo = AuditRepository(session) + + result = asyncio.run(repo.get_log_count(action="DELETE_USER")) + + assert result == 0 + stmt = session.executed[0] + assert len(stmt._where_criteria) == 1 diff --git a/tests/unit/test_auth_dependencies.py b/tests/unit/test_auth_dependencies.py new file mode 100644 index 0000000..0c556a7 --- /dev/null +++ b/tests/unit/test_auth_dependencies.py @@ -0,0 +1,97 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from fastapi import HTTPException + +from app.auth import dependencies +from app.core.security import create_access_token, create_refresh_token +from tests.conftest import make_user + + +def test_get_db_returns_app_state_db(): + request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(db="db-instance"))) + + result = asyncio.run(dependencies.get_db(request)) + + assert result == "db-instance" + + +def test_get_db_raises_when_database_missing(): + request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace())) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(dependencies.get_db(request)) + + assert exc_info.value.status_code == 503 + assert exc_info.value.detail == "Database not initialized" + + +def test_get_current_user_accepts_valid_access_token(): + repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=make_user())) + + result = asyncio.run( + dependencies.get_current_user( + token=create_access_token("tester"), + user_repo=repo, + ) + ) + + assert result.username == "tester" + repo.get_user_by_username.assert_awaited_once_with("tester") + + +def test_get_current_user_rejects_refresh_token(): + repo = SimpleNamespace(get_user_by_username=AsyncMock()) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + dependencies.get_current_user( + token=create_refresh_token("tester"), + user_repo=repo, + ) + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid token type. Access token required." + repo.get_user_by_username.assert_not_awaited() + + +def test_get_current_user_rejects_missing_user(): + repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=None)) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + dependencies.get_current_user( + token=create_access_token("ghost"), + user_repo=repo, + ) + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Could not validate credentials" + + +def test_get_current_active_user_rejects_inactive_user(): + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + dependencies.get_current_active_user( + current_user=make_user(is_active=False), + ) + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Inactive user" + + +def test_get_current_superuser_rejects_non_superuser(): + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + dependencies.get_current_superuser( + current_user=make_user(is_superuser=False), + ) + ) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Not enough privileges. Superuser access required." diff --git a/tests/unit/test_permissions.py b/tests/unit/test_permissions.py new file mode 100644 index 0000000..3dcd86a --- /dev/null +++ b/tests/unit/test_permissions.py @@ -0,0 +1,56 @@ +import asyncio +import pytest +from fastapi import HTTPException + +from app.auth import permissions +from app.domain.models.role import UserRole +from tests.conftest import make_user + + +def test_require_role_allows_higher_privilege_user(): + checker = permissions.require_role(UserRole.OPERATOR) + + result = asyncio.run(checker(current_user=make_user(role=UserRole.ADMIN))) + + assert result.role == UserRole.ADMIN + + +def test_require_role_rejects_insufficient_role(): + checker = permissions.require_role(UserRole.ADMIN) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(checker(current_user=make_user(role=UserRole.USER))) + + assert exc_info.value.status_code == 403 + assert "Required role: ADMIN" in exc_info.value.detail + + +def test_check_resource_owner_allows_admin(): + assert permissions.check_resource_owner( + 99, + make_user(id=1, role=UserRole.ADMIN), + ) is True + + +def test_check_resource_owner_allows_owner(): + assert permissions.check_resource_owner( + 7, + make_user(id=7, role=UserRole.USER), + ) is True + + +def test_check_resource_owner_rejects_other_user(): + assert permissions.check_resource_owner( + 7, + make_user(id=8, role=UserRole.USER), + ) is False + + +def test_require_owner_or_admin_rejects_other_user(): + checker = permissions.require_owner_or_admin(7) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(checker(current_user=make_user(id=8, role=UserRole.USER))) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "You don't have permission to access this resource" diff --git a/tests/unit/test_scada_repository.py b/tests/unit/test_scada_repository.py index 98fbcbb..c8a01e3 100644 --- a/tests/unit/test_scada_repository.py +++ b/tests/unit/test_scada_repository.py @@ -1,9 +1,8 @@ +import asyncio from datetime import datetime, timezone import importlib.util from pathlib import Path -import pytest - def _load_scada_repository(): module_path = ( @@ -50,18 +49,19 @@ class _FakeConnection: return self.cursor_instance -@pytest.mark.asyncio -async def test_update_scada_field_inserts_when_update_hits_no_rows(): +def test_update_scada_field_inserts_when_update_hits_no_rows(): ScadaRepository = _load_scada_repository() conn = _FakeConnection(initial_rowcount=0) point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc) - await ScadaRepository.update_scada_field( - conn, - point_time, - "170490", - "cleaned_value", - 26.5, + asyncio.run( + ScadaRepository.update_scada_field( + conn, + point_time, + "170490", + "cleaned_value", + 26.5, + ) ) assert len(conn.cursor_instance.calls) == 2 @@ -69,18 +69,19 @@ async def test_update_scada_field_inserts_when_update_hits_no_rows(): assert "INSERT INTO scada.scada_data" in conn.cursor_instance.calls[1][0] -@pytest.mark.asyncio -async def test_update_scada_field_skips_insert_when_update_succeeds(): +def test_update_scada_field_skips_insert_when_update_succeeds(): ScadaRepository = _load_scada_repository() conn = _FakeConnection(initial_rowcount=1) point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc) - await ScadaRepository.update_scada_field( - conn, - point_time, - "170490", - "cleaned_value", - 26.5, + asyncio.run( + ScadaRepository.update_scada_field( + conn, + point_time, + "170490", + "cleaned_value", + 26.5, + ) ) assert len(conn.cursor_instance.calls) == 1 diff --git a/tests/unit/test_user_repository.py b/tests/unit/test_user_repository.py new file mode 100644 index 0000000..7d52ad8 --- /dev/null +++ b/tests/unit/test_user_repository.py @@ -0,0 +1,124 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from app.domain.models.role import UserRole +from app.domain.schemas.user import UserCreate, UserUpdate +from app.infra.db.metadb.repositories.user_repository import UserRepository +from tests.conftest import FakeCursor, FakeDB + + +def _user_row(**overrides): + base = { + "id": 1, + "username": "tester", + "email": "tester@example.com", + "hashed_password": "hashed-password", + "role": "USER", + "is_active": True, + "is_superuser": False, + "created_at": "2025-01-01T00:00:00+00:00", + "updated_at": "2025-01-01T00:00:00+00:00", + } + base.update(overrides) + return base + + +def test_create_user_hashes_password_and_returns_model(monkeypatch): + cursor = FakeCursor(fetchone_results=[_user_row()]) + repo = UserRepository(FakeDB(cursor)) + monkeypatch.setattr( + "app.infra.db.metadb.repositories.user_repository.get_password_hash", + lambda password: f"hashed::{password}", + ) + + result = asyncio.run( + repo.create_user( + UserCreate( + username="tester", + email="tester@example.com", + password="secret123", + ) + ) + ) + + assert result is not None + assert result.username == "tester" + assert cursor.executed[0][1]["hashed_password"] == "hashed::secret123" + + +def test_update_user_without_fields_returns_existing_user(monkeypatch): + repo = UserRepository(FakeDB(FakeCursor())) + existing_user = AsyncMock(return_value="existing") + monkeypatch.setattr(repo, "get_user_by_id", existing_user) + + result = asyncio.run(repo.update_user(1, UserUpdate())) + + assert result == "existing" + existing_user.assert_awaited_once_with(1) + + +def test_update_user_builds_dynamic_query(monkeypatch): + cursor = FakeCursor(fetchone_results=[_user_row(role="ADMIN", email="new@example.com")]) + repo = UserRepository(FakeDB(cursor)) + monkeypatch.setattr( + "app.infra.db.metadb.repositories.user_repository.get_password_hash", + lambda password: f"hashed::{password}", + ) + + result = asyncio.run( + repo.update_user( + 1, + UserUpdate( + email="new@example.com", + password="new-secret", + role=UserRole.ADMIN, + is_active=False, + ), + ), + ) + + assert result is not None + query, params = cursor.executed[0] + assert "email = %(email)s" in query + assert "hashed_password = %(hashed_password)s" in query + assert "role = %(role)s" in query + assert "is_active = %(is_active)s" in query + assert params["hashed_password"] == "hashed::new-secret" + assert params["role"] == "ADMIN" + assert params["is_active"] is False + + +def test_delete_user_returns_false_when_execute_raises(): + cursor = FakeCursor() + cursor.execute = AsyncMock(side_effect=RuntimeError("boom")) + repo = UserRepository(FakeDB(cursor)) + + result = asyncio.run(repo.delete_user(1)) + + assert result is False + + +def test_user_exists_short_circuits_without_filters(): + cursor = FakeCursor() + repo = UserRepository(FakeDB(cursor)) + + result = asyncio.run(repo.user_exists()) + + assert result is False + assert cursor.executed == [] + + +def test_user_exists_checks_username_or_email(): + cursor = FakeCursor(fetchone_results=[{"exists": True}]) + repo = UserRepository(FakeDB(cursor)) + + result = asyncio.run( + repo.user_exists(username="tester", email="tester@example.com") + ) + + assert result is True + query, params = cursor.executed[0] + assert "username = %(username)s OR email = %(email)s" in query + assert params == {"username": "tester", "email": "tester@example.com"} diff --git a/失效API排查.md b/失效API排查.md new file mode 100644 index 0000000..8233525 --- /dev/null +++ b/失效API排查.md @@ -0,0 +1,76 @@ +# `app/api/v1/endpoints/` 失效 API 排查与修正 + +排查范围:`app/api/v1/endpoints/` + +结论:本次共确认 5 个问题接口,处理结果如下: + +- **已删除 4 个未实现坏接口** +- **已修正 1 个签名失配接口** + +> 路由统一前缀来自 `app/main.py:71`,以下完整路径均以 `/api/v1` 开头。 + +## 处理结果 + +| Method | API | 原问题 | 处理结果 | +| --- | --- | --- | --- | +| GET | `/api/v1/calculateregion/` | 调用时 `NameError`,底层无 `calculate_region` 实现 | **已删除** | +| GET | `/api/v1/getallregions/` | 调用时 `NameError`,底层无 `get_all_regions` 实现 | **已删除** | +| POST | `/api/v1/generateregion/` | 调用时 `NameError`,底层无 `generate_region` 实现 | **已删除** | +| GET | `/api/v1/calculatedistrictmeteringarea/` | 调用时 `NameError`,仍指向已废弃旧 DMA 入口 | **已删除** | +| GET | `/api/v1/calculateservicearea/` | endpoint 传 `time_index`,实现只接受 `name` | **已修正**,现返回全部时间步结果 | + +## 删除原因 + +### 1. region 相关 3 个接口 + +以下能力在当前 `wndb` / `tjnetwork` 中均不存在: + +- `calculate_region` +- `get_all_regions` +- `generate_region` + +`app/native/wndb/__init__.py` 当前只提供 region CRUD 和 util 能力,不提供 region 计算或批量查询能力。因此这 3 个接口继续保留只会在运行时失败。 + +### 2. DMA 旧入口 + +旧接口 `calculate_district_metering_area(...)` 已不存在,当前只保留 3 个明确变体: + +- `/api/v1/calculatedistrictmeteringareafornodes/` +- `/api/v1/calculatedistrictmeteringareaforregion/` +- `/api/v1/calculatedistrictmeteringareafornetwork/` + +因此旧入口 `/api/v1/calculatedistrictmeteringarea/` 已删除,避免前端继续误用历史接口。 + +## 修正内容 + +### `GET /api/v1/calculateservicearea/` + +原接口问题: + +- endpoint 定义保留 `time_index` +- 实际实现 `calculate_service_area(name)` 只接收 `network/name` +- 调用时会触发参数数量不匹配 + +本次修正后: + +- 移除 `time_index` 查询参数 +- 返回类型改为 `list[dict[str, list[str]]]` +- 接口语义改为:**返回全部时间步的服务区计算结果** + +## 当前可用替代接口 + +DMA 计算请使用: + +- `/api/v1/calculatedistrictmeteringareafornodes/` +- `/api/v1/calculatedistrictmeteringareaforregion/` +- `/api/v1/calculatedistrictmeteringareafornetwork/` + +服务区计算请使用: + +- `/api/v1/calculateservicearea/` + 现在返回全部时间步结果,不再接收 `time_index` + +## 变更文件 + +- `app/api/v1/endpoints/network/regions.py` +- `失效API排查.md`