新增 API 测试用例,修复失效接口问题
This commit is contained in:
@@ -7,11 +7,9 @@ from app.services.tjnetwork import (
|
||||
add_region,
|
||||
add_service_area,
|
||||
add_virtual_district,
|
||||
# calculate_district_metering_area,
|
||||
calculate_district_metering_area_for_network,
|
||||
calculate_district_metering_area_for_nodes,
|
||||
calculate_district_metering_area_for_region,
|
||||
# calculate_region,
|
||||
calculate_service_area,
|
||||
calculate_virtual_district,
|
||||
delete_district_metering_area,
|
||||
@@ -19,13 +17,11 @@ from app.services.tjnetwork import (
|
||||
delete_service_area,
|
||||
delete_virtual_district,
|
||||
generate_district_metering_area,
|
||||
# generate_region,
|
||||
generate_service_area,
|
||||
generate_sub_district_metering_area,
|
||||
generate_virtual_district,
|
||||
get_all_district_metering_area_ids,
|
||||
get_all_district_metering_areas,
|
||||
# get_all_regions,
|
||||
get_all_service_areas,
|
||||
get_all_virtual_districts,
|
||||
get_district_metering_area,
|
||||
@@ -48,18 +44,6 @@ router = APIRouter()
|
||||
# region 32
|
||||
############################################################
|
||||
|
||||
@router.get(
|
||||
"/calculateregion/",
|
||||
summary="计算区域",
|
||||
description="计算指定水网在指定时间步长的区域分区"
|
||||
)
|
||||
async def fastapi_calculate_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
time_index: int = Query(..., description="时间步长索引", ge=0)
|
||||
) -> dict[str, Any]:
|
||||
"""计算区域分区。"""
|
||||
return calculate_region(network, time_index)
|
||||
|
||||
@router.get(
|
||||
"/getregionschema/",
|
||||
summary="获取区域属性架构",
|
||||
@@ -125,62 +109,11 @@ async def fastapi_delete_region(
|
||||
props = await req.json()
|
||||
return delete_region(network, ChangeSet(props))
|
||||
|
||||
@router.get(
|
||||
"/getallregions/",
|
||||
summary="获取所有区域",
|
||||
description="获取指定水网中的所有区域信息"
|
||||
)
|
||||
async def fastapi_get_all_regions(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)")
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取所有区域的信息列表。"""
|
||||
return get_all_regions(network)
|
||||
|
||||
@router.post(
|
||||
"/generateregion/",
|
||||
response_model=None,
|
||||
summary="生成区域分区",
|
||||
description="根据参数自动生成水网的区域分区"
|
||||
)
|
||||
async def fastapi_generate_region(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
inflate_delta: float = Query(..., description="膨胀参数")
|
||||
) -> ChangeSet:
|
||||
"""生成区域分区。"""
|
||||
return generate_region(network, inflate_delta)
|
||||
|
||||
|
||||
############################################################
|
||||
# district_metering_area 33
|
||||
############################################################
|
||||
|
||||
@router.get(
|
||||
"/calculatedistrictmeteringarea/",
|
||||
summary="计算DMA分区",
|
||||
description="计算指定节点集的区域计量(DMA)分区方案"
|
||||
)
|
||||
async def fastapi_calculate_district_metering_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
req: Request = None
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
计算DMA分区。
|
||||
|
||||
请求体格式:
|
||||
{
|
||||
"nodes": 节点ID列表(list[str]),
|
||||
"part_count": 分区数量(int),
|
||||
"part_type": 分区类型(int)
|
||||
}
|
||||
"""
|
||||
props = await req.json()
|
||||
nodes = props["nodes"]
|
||||
part_count = props["part_count"]
|
||||
part_type = props["part_type"]
|
||||
return calculate_district_metering_area(
|
||||
network, nodes, part_count, part_type
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/calculatedistrictmeteringareaforregion/",
|
||||
summary="计算区域内DMA分区",
|
||||
@@ -368,14 +301,13 @@ async def fastapi_generate_sub_district_metering_area(
|
||||
@router.get(
|
||||
"/calculateservicearea/",
|
||||
summary="计算服务区",
|
||||
description="计算指定水网在指定时间步长的服务区分区"
|
||||
description="计算指定水网的服务区分区,返回全部时间步结果"
|
||||
)
|
||||
async def fastapi_calculate_service_area(
|
||||
network: str = Query(..., description="管网名称(或数据库名称)"),
|
||||
time_index: int = Query(..., description="时间步长索引", ge=0)
|
||||
) -> dict[str, Any]:
|
||||
"""计算服务区分区。"""
|
||||
return calculate_service_area(network, time_index)
|
||||
) -> list[dict[str, list[str]]]:
|
||||
"""计算服务区分区,返回全部时间步结果。"""
|
||||
return calculate_service_area(network)
|
||||
|
||||
@router.get(
|
||||
"/getserviceareaschema/",
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import audit as audit_endpoint
|
||||
from app.auth.metadata_dependencies import (
|
||||
get_current_metadata_admin,
|
||||
get_current_metadata_user,
|
||||
)
|
||||
from tests.conftest import build_test_app, make_audit_log
|
||||
|
||||
|
||||
def _build_client(repo, *, metadata_admin=None, metadata_user=None) -> TestClient:
|
||||
app = build_test_app(audit_endpoint.router, "/audit")
|
||||
app.dependency_overrides[audit_endpoint.get_audit_repository] = lambda: repo
|
||||
if metadata_admin is not None:
|
||||
app.dependency_overrides[get_current_metadata_admin] = lambda: metadata_admin
|
||||
if metadata_user is not None:
|
||||
app.dependency_overrides[get_current_metadata_user] = lambda: metadata_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_get_audit_logs_passes_filters():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(action="LOGIN")]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get(
|
||||
"/audit/logs",
|
||||
params={
|
||||
"action": "LOGIN",
|
||||
"resource_type": "user",
|
||||
"skip": 2,
|
||||
"limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()[0]["action"] == "LOGIN"
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["action"] == "LOGIN"
|
||||
assert kwargs["resource_type"] == "user"
|
||||
assert kwargs["skip"] == 2
|
||||
assert kwargs["limit"] == 5
|
||||
|
||||
|
||||
def test_get_audit_logs_count_returns_count_payload():
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(),
|
||||
"get_log_count": AsyncMock(return_value=7),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_admin=object())
|
||||
|
||||
response = client.get("/audit/logs/count", params={"action": "DELETE_USER"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"count": 7}
|
||||
repo.get_log_count.assert_awaited_once()
|
||||
assert repo.get_log_count.await_args.kwargs["action"] == "DELETE_USER"
|
||||
|
||||
|
||||
def test_get_my_audit_logs_forces_current_user_id():
|
||||
current_user = type("User", (), {"id": make_audit_log().user_id})()
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_logs": AsyncMock(return_value=[make_audit_log(user_id=current_user.id)]),
|
||||
"get_log_count": AsyncMock(),
|
||||
},
|
||||
)()
|
||||
client = _build_client(repo, metadata_user=current_user)
|
||||
|
||||
response = client.get("/audit/logs/my", params={"limit": 3})
|
||||
|
||||
assert response.status_code == 200
|
||||
repo.get_logs.assert_awaited_once()
|
||||
kwargs = repo.get_logs.await_args.kwargs
|
||||
assert kwargs["user_id"] == current_user.id
|
||||
assert kwargs["limit"] == 3
|
||||
@@ -0,0 +1,139 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import auth as auth_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.core.security import create_access_token, create_refresh_token, get_password_hash
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, current_user=None) -> TestClient:
|
||||
app = build_test_app(auth_endpoint.router, "/api/v1/auth")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_register_success():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[False, False]),
|
||||
create_user=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json()["username"] == "tester"
|
||||
|
||||
|
||||
def test_register_rejects_duplicate_username():
|
||||
repo = SimpleNamespace(
|
||||
user_exists=AsyncMock(side_effect=[True]),
|
||||
create_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"password": "secret123",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Username already registered"
|
||||
repo.create_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_login_supports_email_lookup():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=None),
|
||||
get_user_by_email=AsyncMock(
|
||||
return_value=make_user(
|
||||
email="tester@example.com",
|
||||
hashed_password=hashed_password,
|
||||
)
|
||||
),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login",
|
||||
data={"username": "tester@example.com", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
repo.get_user_by_email.assert_awaited_once_with("tester@example.com")
|
||||
|
||||
|
||||
def test_login_simple_uses_query_params():
|
||||
hashed_password = get_password_hash("secret123")
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(
|
||||
return_value=make_user(hashed_password=hashed_password)
|
||||
),
|
||||
get_user_by_email=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/login/simple",
|
||||
params={"username": "tester", "password": "secret123"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["token_type"] == "bearer"
|
||||
|
||||
|
||||
def test_me_returns_current_user_info():
|
||||
client = _build_client(SimpleNamespace(), current_user=make_user(username="alice"))
|
||||
|
||||
response = client.get("/api/v1/auth/me")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["username"] == "alice"
|
||||
|
||||
|
||||
def test_refresh_rejects_access_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock())
|
||||
client = _build_client(repo)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": create_access_token("tester")},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_refresh_success_returns_new_access_token():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_username=AsyncMock(return_value=make_user()),
|
||||
)
|
||||
client = _build_client(repo)
|
||||
refresh_token = create_refresh_token("tester")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
params={"refresh_token": refresh_token},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["refresh_token"] == refresh_token
|
||||
assert payload["token_type"] == "bearer"
|
||||
@@ -0,0 +1,152 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _load_project_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.project_info",
|
||||
{},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"list_project": lambda: ["demo"],
|
||||
"have_project": lambda network: network == "demo",
|
||||
"create_project": lambda network: None,
|
||||
"delete_project": lambda network: None,
|
||||
"is_project_open": lambda network: False,
|
||||
"open_project": lambda network: None,
|
||||
"close_project": lambda network: None,
|
||||
"copy_project": lambda source, target: None,
|
||||
"import_inp": lambda network, cs: {"ok": True},
|
||||
"export_inp": lambda network, version: DummyChangeSet({"kind": "export"}),
|
||||
"read_inp": lambda network, inp: True,
|
||||
"dump_inp": lambda network, inp: True,
|
||||
"get_all_vertices": lambda network: [],
|
||||
"get_all_scada_elements": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_extension_data": lambda network, key: None,
|
||||
"convert_inp_v3_to_v2": lambda inp: DummyChangeSet({"inp": inp}),
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.auth.project_dependencies",
|
||||
{"get_metadata_repository": lambda: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.postgresql.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.infra.db.timescaledb.database",
|
||||
{"get_database_instance": lambda network: None},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_project_endpoints_module",
|
||||
"app/api/v1/endpoints/project.py",
|
||||
)
|
||||
|
||||
|
||||
def test_project_info_returns_404_when_missing(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=None))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "missing"})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Project missing not found"
|
||||
|
||||
|
||||
def test_project_info_returns_geoserver_payload(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
detail = SimpleNamespace(
|
||||
project_id=uuid4(),
|
||||
name="Demo Project",
|
||||
code="demo",
|
||||
description="desc",
|
||||
gs_workspace="ws",
|
||||
map_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
status="active",
|
||||
geoserver=SimpleNamespace(
|
||||
gs_base_url="http://gs",
|
||||
gs_admin_user="admin",
|
||||
gs_datastore_name="store",
|
||||
default_extent={"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4},
|
||||
srid=4326,
|
||||
),
|
||||
)
|
||||
repo = SimpleNamespace(get_project_detail_by_code=AsyncMock(return_value=detail))
|
||||
app = build_test_app(module.router, "/api/v1")
|
||||
app.dependency_overrides[module.get_metadata_repository] = lambda: repo
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/project_info/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["code"] == "demo"
|
||||
assert payload["geoserver"]["gs_base_url"] == "http://gs"
|
||||
|
||||
|
||||
def test_open_project_returns_network_even_when_db_connection_fails(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
called = []
|
||||
|
||||
monkeypatch.setattr(module, "open_project", lambda network: called.append(network))
|
||||
|
||||
async def failing_get_pg_db(network):
|
||||
raise RuntimeError("db down")
|
||||
|
||||
monkeypatch.setattr(module, "get_pg_db", failing_get_pg_db)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post("/api/v1/openproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "demo"
|
||||
assert called == ["demo"]
|
||||
|
||||
|
||||
def test_project_lock_lifecycle(monkeypatch):
|
||||
module = _load_project_module(monkeypatch)
|
||||
module.lockedPrjs.clear()
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
first_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
second_lock = client.post("/api/v1/lockproject/", params={"network": "demo"})
|
||||
locked_by_me = client.get("/api/v1/isprojectlockedbyme/", params={"network": "demo"})
|
||||
unlock = client.post("/api/v1/unlockproject/", params={"network": "demo"})
|
||||
locked = client.get("/api/v1/isprojectlocked/", params={"network": "demo"})
|
||||
|
||||
assert first_lock.json() == 0
|
||||
assert second_lock.json() == 1
|
||||
assert locked_by_me.json() is True
|
||||
assert unlock.json() is True
|
||||
assert locked.json() is False
|
||||
@@ -0,0 +1,154 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
class DummyChangeSet:
|
||||
def __init__(self, operations=None):
|
||||
if operations is None:
|
||||
self.operations = []
|
||||
elif isinstance(operations, dict):
|
||||
self.operations = [operations]
|
||||
else:
|
||||
self.operations = operations
|
||||
|
||||
|
||||
def _noop(*args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def _load_regions_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"Any": Any,
|
||||
"ChangeSet": DummyChangeSet,
|
||||
"add_district_metering_area": _noop,
|
||||
"add_region": _noop,
|
||||
"add_service_area": _noop,
|
||||
"add_virtual_district": _noop,
|
||||
"calculate_district_metering_area_for_network": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_nodes": lambda *args, **kwargs: [],
|
||||
"calculate_district_metering_area_for_region": lambda *args, **kwargs: [],
|
||||
"calculate_service_area": lambda network: [],
|
||||
"calculate_virtual_district": lambda *args, **kwargs: {},
|
||||
"delete_district_metering_area": _noop,
|
||||
"delete_region": _noop,
|
||||
"delete_service_area": _noop,
|
||||
"delete_virtual_district": _noop,
|
||||
"generate_district_metering_area": _noop,
|
||||
"generate_service_area": _noop,
|
||||
"generate_sub_district_metering_area": _noop,
|
||||
"generate_virtual_district": _noop,
|
||||
"get_all_district_metering_area_ids": lambda network: [],
|
||||
"get_all_district_metering_areas": lambda network: [],
|
||||
"get_all_service_areas": lambda network: [],
|
||||
"get_all_virtual_districts": lambda network: [],
|
||||
"get_district_metering_area": lambda network, area_id: {},
|
||||
"get_district_metering_area_schema": lambda network: {},
|
||||
"get_region": lambda network, region_id: {},
|
||||
"get_region_schema": lambda network: {},
|
||||
"get_service_area": lambda network, area_id: {},
|
||||
"get_service_area_schema": lambda network: {},
|
||||
"get_virtual_district": lambda network, area_id: {},
|
||||
"get_virtual_district_schema": lambda network: {},
|
||||
"set_district_metering_area": _noop,
|
||||
"set_region": _noop,
|
||||
"set_service_area": _noop,
|
||||
"set_virtual_district": _noop,
|
||||
},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_regions_endpoints_module",
|
||||
"app/api/v1/endpoints/network/regions.py",
|
||||
)
|
||||
|
||||
|
||||
def test_removed_routes_are_absent_and_return_404(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
openapi = client.get("/openapi.json").json()
|
||||
|
||||
assert "/api/v1/calculateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/getallregions/" not in openapi["paths"]
|
||||
assert "/api/v1/generateregion/" not in openapi["paths"]
|
||||
assert "/api/v1/calculatedistrictmeteringarea/" not in openapi["paths"]
|
||||
assert client.get("/api/v1/calculateregion/", params={"network": "demo", "time_index": 0}).status_code == 404
|
||||
assert client.get("/api/v1/calculatedistrictmeteringarea/", params={"network": "demo"}).status_code == 404
|
||||
|
||||
|
||||
def test_calculate_service_area_contract_uses_only_network(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"calculate_service_area",
|
||||
lambda network: calls.append(network) or [{"source-1": ["n1", "n2"]}],
|
||||
)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/calculateservicearea/",
|
||||
params={"network": "demo", "time_index": 5},
|
||||
)
|
||||
schema = client.get("/openapi.json").json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [{"source-1": ["n1", "n2"]}]
|
||||
assert calls == ["demo"]
|
||||
parameter_names = [
|
||||
item["name"]
|
||||
for item in schema["paths"]["/api/v1/calculateservicearea/"]["get"]["parameters"]
|
||||
]
|
||||
assert parameter_names == ["network"]
|
||||
|
||||
|
||||
def test_add_district_metering_area_converts_boundary_to_tuples(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_add(network, change_set):
|
||||
captured["network"] = network
|
||||
captured["boundary"] = change_set.operations[0]["boundary"]
|
||||
return {"ok": True}
|
||||
|
||||
monkeypatch.setattr(module, "add_district_metering_area", fake_add)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/adddistrictmeteringarea/",
|
||||
params={"network": "demo"},
|
||||
json={"id": "dma-1", "boundary": [[1, 2], [3, 4], [1, 2]]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured == {
|
||||
"network": "demo",
|
||||
"boundary": [(1, 2), (3, 4), (1, 2)],
|
||||
}
|
||||
|
||||
|
||||
def test_generate_virtual_district_reads_centers_from_body(monkeypatch):
|
||||
module = _load_regions_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_generate(network, centers, inflate_delta):
|
||||
captured["args"] = (network, centers, inflate_delta)
|
||||
return {"generated": True}
|
||||
|
||||
monkeypatch.setattr(module, "generate_virtual_district", fake_generate)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/generatevirtualdistrict/",
|
||||
params={"network": "demo", "inflate_delta": 0.75},
|
||||
json={"centers": ["J1", "J2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured["args"] == ("demo", ["J1", "J2"], 0.75)
|
||||
@@ -0,0 +1,175 @@
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from tests.conftest import build_test_app, install_stub, load_module_from_path
|
||||
|
||||
|
||||
def _load_simulation_module(monkeypatch):
|
||||
install_stub(monkeypatch, "app.services", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation",
|
||||
{"run_simulation": lambda **kwargs: None},
|
||||
)
|
||||
install_stub(monkeypatch, "app.services.globals", {})
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.tjnetwork",
|
||||
{
|
||||
"run_project": lambda network: "report",
|
||||
"run_project_return_dict": lambda network: {"output": {}, "report": "ok"},
|
||||
"run_inp": lambda network: "inp-report",
|
||||
"dump_output": lambda output: f"dump::{output}",
|
||||
},
|
||||
)
|
||||
install_stub(monkeypatch, "app.algorithms", package=True)
|
||||
install_stub(monkeypatch, "app.algorithms.simulation", package=True)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.simulation.scenarios",
|
||||
{
|
||||
"burst_analysis": lambda *args, **kwargs: "burst",
|
||||
"valve_close_analysis": lambda *args, **kwargs: "valve",
|
||||
"flushing_analysis": lambda *args, **kwargs: "flush",
|
||||
"contaminant_simulation": lambda *args, **kwargs: "contaminant",
|
||||
"age_analysis": lambda *args, **kwargs: "age",
|
||||
"pressure_regulation": lambda *args, **kwargs: "pressure",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.algorithms.sensor",
|
||||
{
|
||||
"pressure_sensor_placement_sensitivity": lambda *args, **kwargs: [],
|
||||
"pressure_sensor_placement_kmeans": lambda *args, **kwargs: [],
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.network_import",
|
||||
{"network_update": lambda *args, **kwargs: "updated"},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.simulation_ops",
|
||||
{
|
||||
"project_management": lambda *args, **kwargs: "managed",
|
||||
"scheduling_simulation": lambda *args, **kwargs: "scheduled",
|
||||
"daily_scheduling_simulation": lambda *args, **kwargs: "daily",
|
||||
},
|
||||
)
|
||||
install_stub(
|
||||
monkeypatch,
|
||||
"app.services.valve_isolation",
|
||||
{"analyze_valve_isolation": lambda *args, **kwargs: {}},
|
||||
)
|
||||
return load_module_from_path(
|
||||
"tests_simulation_endpoints_module",
|
||||
"app/api/v1/endpoints/simulation.py",
|
||||
)
|
||||
|
||||
|
||||
def test_run_project_endpoint_returns_plain_text(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.setattr(module, "run_project", lambda network: f"report::{network}")
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.get("/api/v1/runproject/", params={"network": "demo"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.text == "report::demo"
|
||||
|
||||
|
||||
def test_scheduling_analysis_maps_request_body(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_schedule(network, start_time, pump_control, tank_id, water_plant_output_id, time_delta):
|
||||
captured["args"] = (
|
||||
network,
|
||||
start_time,
|
||||
pump_control,
|
||||
tank_id,
|
||||
water_plant_output_id,
|
||||
time_delta,
|
||||
)
|
||||
return "scheduled"
|
||||
|
||||
monkeypatch.setattr(module, "scheduling_simulation", fake_schedule)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/scheduling_analysis/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1, 0, 1]},
|
||||
"tank_id": "T1",
|
||||
"water_plant_output_id": "R1",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "scheduled"
|
||||
assert captured["args"] == (
|
||||
"demo",
|
||||
"2025-01-01T08:00:00+08:00",
|
||||
{"P1": [1, 0, 1]},
|
||||
"T1",
|
||||
"R1",
|
||||
300,
|
||||
)
|
||||
|
||||
|
||||
def test_project_management_maps_named_arguments(monkeypatch):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
captured = {}
|
||||
|
||||
def fake_project_management(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return "managed"
|
||||
|
||||
monkeypatch.setattr(module, "project_management", fake_project_management)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/project_management/",
|
||||
json={
|
||||
"network": "demo",
|
||||
"start_time": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_init_level": {"T1": 10.0},
|
||||
"region_demand": {"R1": 20.0},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == "managed"
|
||||
assert captured == {
|
||||
"prj_name": "demo",
|
||||
"start_datetime": "2025-01-01T08:00:00+08:00",
|
||||
"pump_control": {"P1": [1]},
|
||||
"tank_initial_level_control": {"T1": 10.0},
|
||||
"region_demand_control": {"R1": 20.0},
|
||||
}
|
||||
|
||||
|
||||
def test_network_update_surfaces_service_error(monkeypatch, tmp_path):
|
||||
module = _load_simulation_module(monkeypatch)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
def boom(_path):
|
||||
raise RuntimeError("write failed")
|
||||
|
||||
monkeypatch.setattr(module, "network_update", boom)
|
||||
client = TestClient(build_test_app(module.router, "/api/v1"))
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/network_update/",
|
||||
files={"file": ("update.txt", b"payload")},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "数据库操作失败: write failed" in response.json()["detail"]
|
||||
assert list(Path(tmp_path).glob("network_update_*"))
|
||||
@@ -0,0 +1,95 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.api.v1.endpoints import user_management as user_management_endpoint
|
||||
from app.auth.dependencies import get_current_active_user, get_user_repository
|
||||
from app.auth.permissions import get_current_admin
|
||||
from app.domain.models.role import UserRole
|
||||
from tests.conftest import build_test_app, make_user
|
||||
|
||||
|
||||
def _build_client(repo, *, current_user=None, admin_user=None) -> TestClient:
|
||||
app = build_test_app(user_management_endpoint.router, "/users")
|
||||
app.dependency_overrides[get_user_repository] = lambda: repo
|
||||
if current_user is not None:
|
||||
app.dependency_overrides[get_current_active_user] = lambda: current_user
|
||||
if admin_user is not None:
|
||||
app.dependency_overrides[get_current_admin] = lambda: admin_user
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_list_users_requires_admin_role():
|
||||
repo = SimpleNamespace(
|
||||
get_all_users=AsyncMock(
|
||||
return_value=[
|
||||
make_user(id=1, username="admin", role=UserRole.ADMIN),
|
||||
make_user(id=2, username="user2"),
|
||||
]
|
||||
)
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
current_user=make_user(id=1, role=UserRole.ADMIN),
|
||||
)
|
||||
|
||||
response = client.get("/users/", params={"skip": 5, "limit": 2})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
repo.get_all_users.assert_awaited_once_with(skip=5, limit=2)
|
||||
|
||||
|
||||
def test_get_user_rejects_non_owner_non_admin():
|
||||
repo = SimpleNamespace(get_user_by_id=AsyncMock())
|
||||
client = _build_client(repo, current_user=make_user(id=2, role=UserRole.USER))
|
||||
|
||||
response = client.get("/users/3")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "You don't have permission to view this user"
|
||||
repo.get_user_by_id.assert_not_awaited()
|
||||
|
||||
|
||||
def test_update_user_blocks_role_change_for_non_admin():
|
||||
repo = SimpleNamespace(
|
||||
get_user_by_id=AsyncMock(return_value=make_user(id=1)),
|
||||
update_user=AsyncMock(),
|
||||
)
|
||||
client = _build_client(repo, current_user=make_user(id=1, role=UserRole.USER))
|
||||
|
||||
response = client.put("/users/1", json={"role": "ADMIN"})
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "Only admins can change user roles"
|
||||
repo.update_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_delete_user_blocks_self_delete_for_admin():
|
||||
admin_user = make_user(id=1, role=UserRole.ADMIN, is_superuser=True)
|
||||
repo = SimpleNamespace(delete_user=AsyncMock())
|
||||
client = _build_client(repo, admin_user=admin_user)
|
||||
|
||||
response = client.delete("/users/1")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "You cannot delete your own account"
|
||||
repo.delete_user.assert_not_awaited()
|
||||
|
||||
|
||||
def test_activate_user_updates_active_flag():
|
||||
repo = SimpleNamespace(
|
||||
update_user=AsyncMock(return_value=make_user(id=2, is_active=True)),
|
||||
)
|
||||
client = _build_client(
|
||||
repo,
|
||||
admin_user=make_user(id=1, role=UserRole.ADMIN, is_superuser=True),
|
||||
)
|
||||
|
||||
response = client.post("/users/2/activate")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["is_active"] is True
|
||||
user_update = repo.update_user.await_args.args[1]
|
||||
assert user_update.is_active is True
|
||||
@@ -0,0 +1,36 @@
|
||||
from jose import jwt
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
def test_password_hash_roundtrip():
|
||||
hashed = get_password_hash("secret123")
|
||||
assert hashed != "secret123"
|
||||
assert verify_password("secret123", hashed) is True
|
||||
assert verify_password("wrong", hashed) is False
|
||||
|
||||
|
||||
def test_create_access_token_sets_access_type():
|
||||
token = create_access_token("alice")
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "alice"
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
|
||||
|
||||
def test_create_refresh_token_sets_refresh_type():
|
||||
token = create_refresh_token("alice")
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "alice"
|
||||
assert payload["type"] == "refresh"
|
||||
assert "exp" in payload
|
||||
assert "iat" in payload
|
||||
+189
-6
@@ -1,14 +1,197 @@
|
||||
import pytest
|
||||
import sys
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# 自动添加项目根目录到路径(处理项目结构)
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def run_this_test(test_file):
|
||||
"""自定义函数:运行单个测试文件(类似pytest)"""
|
||||
# 提取测试文件名(无扩展名)
|
||||
test_name = os.path.splitext(os.path.basename(test_file))[0]
|
||||
# 使用pytest运行(自动处理导入)
|
||||
pytest.main([test_file, "-v"])
|
||||
|
||||
|
||||
def build_test_app(router, prefix: str = "") -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix=prefix)
|
||||
return app
|
||||
|
||||
|
||||
def load_module_from_path(module_name: str, relative_path: str):
|
||||
module_path = PROJECT_ROOT / relative_path
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def install_stub(monkeypatch, name: str, attrs: dict | None = None, package: bool = False):
|
||||
module = types.ModuleType(name)
|
||||
if package:
|
||||
module.__path__ = []
|
||||
if attrs:
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
parent_name, _, child_name = name.rpartition(".")
|
||||
if parent_name:
|
||||
parent = sys.modules.get(parent_name)
|
||||
if parent is None:
|
||||
try:
|
||||
parent = importlib.import_module(parent_name)
|
||||
except Exception:
|
||||
parent = types.ModuleType(parent_name)
|
||||
parent.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, parent_name, parent)
|
||||
setattr(parent, child_name, module)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class FakeCursor:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
fetchone_results=None,
|
||||
fetchall_results=None,
|
||||
rowcount: int = 0,
|
||||
rowcounts=None,
|
||||
):
|
||||
self._fetchone_results = list(fetchone_results or [])
|
||||
self._fetchall_results = list(fetchall_results or [])
|
||||
self._rowcounts = list(rowcounts or [])
|
||||
self.rowcount = rowcount
|
||||
self.executed: list[tuple[str, dict | tuple | None]] = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def execute(self, query, params=None):
|
||||
self.executed.append((str(query), params))
|
||||
if self._rowcounts:
|
||||
self.rowcount = self._rowcounts.pop(0)
|
||||
|
||||
async def fetchone(self):
|
||||
if self._fetchone_results:
|
||||
return self._fetchone_results.pop(0)
|
||||
return None
|
||||
|
||||
async def fetchall(self):
|
||||
if self._fetchall_results:
|
||||
return self._fetchall_results.pop(0)
|
||||
return []
|
||||
|
||||
|
||||
class FakeConnection:
|
||||
def __init__(self, cursor: FakeCursor):
|
||||
self.cursor_instance = cursor
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def cursor(self):
|
||||
return self.cursor_instance
|
||||
|
||||
|
||||
class FakeDB:
|
||||
def __init__(self, cursor: FakeCursor):
|
||||
self.connection = FakeConnection(cursor)
|
||||
|
||||
def get_connection(self):
|
||||
return self.connection
|
||||
|
||||
|
||||
class FakeExecuteResult:
|
||||
def __init__(self, *, rows=None, scalar_value=None):
|
||||
self._rows = list(rows or [])
|
||||
self._scalar_value = scalar_value
|
||||
|
||||
def scalars(self):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return self._rows
|
||||
|
||||
def scalar(self):
|
||||
return self._scalar_value
|
||||
|
||||
|
||||
class FakeAsyncSession:
|
||||
def __init__(self, execute_results=None):
|
||||
self._execute_results = list(execute_results or [])
|
||||
self.executed = []
|
||||
self.added = []
|
||||
self.commit_count = 0
|
||||
self.refreshed = []
|
||||
|
||||
def add(self, obj):
|
||||
self.added.append(obj)
|
||||
|
||||
async def execute(self, stmt):
|
||||
self.executed.append(stmt)
|
||||
if self._execute_results:
|
||||
return self._execute_results.pop(0)
|
||||
return FakeExecuteResult()
|
||||
|
||||
async def commit(self):
|
||||
self.commit_count += 1
|
||||
|
||||
async def refresh(self, obj):
|
||||
self.refreshed.append(obj)
|
||||
|
||||
|
||||
def make_user(**overrides):
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserInDB
|
||||
|
||||
data = {
|
||||
"id": 1,
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"hashed_password": "hashed-password",
|
||||
"role": UserRole.USER,
|
||||
"is_active": True,
|
||||
"is_superuser": False,
|
||||
"created_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
"updated_at": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
data.update(overrides)
|
||||
return UserInDB(**data)
|
||||
|
||||
|
||||
def make_audit_log(**overrides):
|
||||
data = {
|
||||
"id": uuid4(),
|
||||
"user_id": uuid4(),
|
||||
"project_id": uuid4(),
|
||||
"action": "LOGIN",
|
||||
"resource_type": "user",
|
||||
"resource_id": "1",
|
||||
"ip_address": "127.0.0.1",
|
||||
"request_method": "GET",
|
||||
"request_path": "/audit/logs",
|
||||
"request_data": {"ok": True},
|
||||
"response_status": 200,
|
||||
"timestamp": datetime(2025, 1, 1, tzinfo=timezone.utc),
|
||||
}
|
||||
data.update(overrides)
|
||||
return SimpleNamespace(**data)
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from app.infra.db.metadb.repositories.audit_repository import AuditRepository
|
||||
from tests.conftest import FakeAsyncSession, FakeExecuteResult, make_audit_log
|
||||
|
||||
|
||||
def test_create_log_adds_commits_and_refreshes(monkeypatch):
|
||||
class FakeAuditLog:
|
||||
def __init__(self, **kwargs):
|
||||
self.id = uuid4()
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
session = FakeAsyncSession()
|
||||
repo = AuditRepository(session)
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.audit_repository.models.AuditLog",
|
||||
FakeAuditLog,
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.create_log(
|
||||
action="LOGIN",
|
||||
request_method="POST",
|
||||
request_path="/auth/login",
|
||||
response_status=200,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.action == "LOGIN"
|
||||
assert result.request_method == "POST"
|
||||
assert session.commit_count == 1
|
||||
assert len(session.added) == 1
|
||||
assert len(session.refreshed) == 1
|
||||
|
||||
|
||||
def test_get_logs_builds_filtered_query_and_returns_models():
|
||||
log = make_audit_log(action="UPDATE_USER", resource_type="user")
|
||||
session = FakeAsyncSession(
|
||||
execute_results=[FakeExecuteResult(rows=[log])],
|
||||
)
|
||||
repo = AuditRepository(session)
|
||||
user_id = uuid4()
|
||||
project_id = uuid4()
|
||||
start_time = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
results = asyncio.run(
|
||||
repo.get_logs(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
action="UPDATE_USER",
|
||||
resource_type="user",
|
||||
start_time=start_time,
|
||||
skip=5,
|
||||
limit=10,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].action == "UPDATE_USER"
|
||||
stmt = session.executed[0]
|
||||
assert len(stmt._where_criteria) == 5
|
||||
assert stmt._offset == 5
|
||||
assert stmt._limit == 10
|
||||
|
||||
|
||||
def test_get_log_count_returns_zero_when_scalar_none():
|
||||
session = FakeAsyncSession(
|
||||
execute_results=[FakeExecuteResult(scalar_value=None)],
|
||||
)
|
||||
repo = AuditRepository(session)
|
||||
|
||||
result = asyncio.run(repo.get_log_count(action="DELETE_USER"))
|
||||
|
||||
assert result == 0
|
||||
stmt = session.executed[0]
|
||||
assert len(stmt._where_criteria) == 1
|
||||
@@ -0,0 +1,97 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth import dependencies
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from tests.conftest import make_user
|
||||
|
||||
|
||||
def test_get_db_returns_app_state_db():
|
||||
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace(db="db-instance")))
|
||||
|
||||
result = asyncio.run(dependencies.get_db(request))
|
||||
|
||||
assert result == "db-instance"
|
||||
|
||||
|
||||
def test_get_db_raises_when_database_missing():
|
||||
request = SimpleNamespace(app=SimpleNamespace(state=SimpleNamespace()))
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(dependencies.get_db(request))
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert exc_info.value.detail == "Database not initialized"
|
||||
|
||||
|
||||
def test_get_current_user_accepts_valid_access_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=make_user()))
|
||||
|
||||
result = asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_access_token("tester"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert result.username == "tester"
|
||||
repo.get_user_by_username.assert_awaited_once_with("tester")
|
||||
|
||||
|
||||
def test_get_current_user_rejects_refresh_token():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock())
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_refresh_token("tester"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Invalid token type. Access token required."
|
||||
repo.get_user_by_username.assert_not_awaited()
|
||||
|
||||
|
||||
def test_get_current_user_rejects_missing_user():
|
||||
repo = SimpleNamespace(get_user_by_username=AsyncMock(return_value=None))
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_user(
|
||||
token=create_access_token("ghost"),
|
||||
user_repo=repo,
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert exc_info.value.detail == "Could not validate credentials"
|
||||
|
||||
|
||||
def test_get_current_active_user_rejects_inactive_user():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_active_user(
|
||||
current_user=make_user(is_active=False),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Inactive user"
|
||||
|
||||
|
||||
def test_get_current_superuser_rejects_non_superuser():
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(
|
||||
dependencies.get_current_superuser(
|
||||
current_user=make_user(is_superuser=False),
|
||||
)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "Not enough privileges. Superuser access required."
|
||||
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.auth import permissions
|
||||
from app.domain.models.role import UserRole
|
||||
from tests.conftest import make_user
|
||||
|
||||
|
||||
def test_require_role_allows_higher_privilege_user():
|
||||
checker = permissions.require_role(UserRole.OPERATOR)
|
||||
|
||||
result = asyncio.run(checker(current_user=make_user(role=UserRole.ADMIN)))
|
||||
|
||||
assert result.role == UserRole.ADMIN
|
||||
|
||||
|
||||
def test_require_role_rejects_insufficient_role():
|
||||
checker = permissions.require_role(UserRole.ADMIN)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(checker(current_user=make_user(role=UserRole.USER)))
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Required role: ADMIN" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_check_resource_owner_allows_admin():
|
||||
assert permissions.check_resource_owner(
|
||||
99,
|
||||
make_user(id=1, role=UserRole.ADMIN),
|
||||
) is True
|
||||
|
||||
|
||||
def test_check_resource_owner_allows_owner():
|
||||
assert permissions.check_resource_owner(
|
||||
7,
|
||||
make_user(id=7, role=UserRole.USER),
|
||||
) is True
|
||||
|
||||
|
||||
def test_check_resource_owner_rejects_other_user():
|
||||
assert permissions.check_resource_owner(
|
||||
7,
|
||||
make_user(id=8, role=UserRole.USER),
|
||||
) is False
|
||||
|
||||
|
||||
def test_require_owner_or_admin_rejects_other_user():
|
||||
checker = permissions.require_owner_or_admin(7)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(checker(current_user=make_user(id=8, role=UserRole.USER)))
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert exc_info.value.detail == "You don't have permission to access this resource"
|
||||
@@ -1,9 +1,8 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _load_scada_repository():
|
||||
module_path = (
|
||||
@@ -50,18 +49,19 @@ class _FakeConnection:
|
||||
return self.cursor_instance
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
ScadaRepository = _load_scada_repository()
|
||||
conn = _FakeConnection(initial_rowcount=0)
|
||||
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
asyncio.run(
|
||||
ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(conn.cursor_instance.calls) == 2
|
||||
@@ -69,18 +69,19 @@ async def test_update_scada_field_inserts_when_update_hits_no_rows():
|
||||
assert "INSERT INTO scada.scada_data" in conn.cursor_instance.calls[1][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_scada_field_skips_insert_when_update_succeeds():
|
||||
def test_update_scada_field_skips_insert_when_update_succeeds():
|
||||
ScadaRepository = _load_scada_repository()
|
||||
conn = _FakeConnection(initial_rowcount=1)
|
||||
point_time = datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
await ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
asyncio.run(
|
||||
ScadaRepository.update_scada_field(
|
||||
conn,
|
||||
point_time,
|
||||
"170490",
|
||||
"cleaned_value",
|
||||
26.5,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(conn.cursor_instance.calls) == 1
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.models.role import UserRole
|
||||
from app.domain.schemas.user import UserCreate, UserUpdate
|
||||
from app.infra.db.metadb.repositories.user_repository import UserRepository
|
||||
from tests.conftest import FakeCursor, FakeDB
|
||||
|
||||
|
||||
def _user_row(**overrides):
|
||||
base = {
|
||||
"id": 1,
|
||||
"username": "tester",
|
||||
"email": "tester@example.com",
|
||||
"hashed_password": "hashed-password",
|
||||
"role": "USER",
|
||||
"is_active": True,
|
||||
"is_superuser": False,
|
||||
"created_at": "2025-01-01T00:00:00+00:00",
|
||||
"updated_at": "2025-01-01T00:00:00+00:00",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def test_create_user_hashes_password_and_returns_model(monkeypatch):
|
||||
cursor = FakeCursor(fetchone_results=[_user_row()])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
|
||||
lambda password: f"hashed::{password}",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.create_user(
|
||||
UserCreate(
|
||||
username="tester",
|
||||
email="tester@example.com",
|
||||
password="secret123",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.username == "tester"
|
||||
assert cursor.executed[0][1]["hashed_password"] == "hashed::secret123"
|
||||
|
||||
|
||||
def test_update_user_without_fields_returns_existing_user(monkeypatch):
|
||||
repo = UserRepository(FakeDB(FakeCursor()))
|
||||
existing_user = AsyncMock(return_value="existing")
|
||||
monkeypatch.setattr(repo, "get_user_by_id", existing_user)
|
||||
|
||||
result = asyncio.run(repo.update_user(1, UserUpdate()))
|
||||
|
||||
assert result == "existing"
|
||||
existing_user.assert_awaited_once_with(1)
|
||||
|
||||
|
||||
def test_update_user_builds_dynamic_query(monkeypatch):
|
||||
cursor = FakeCursor(fetchone_results=[_user_row(role="ADMIN", email="new@example.com")])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
monkeypatch.setattr(
|
||||
"app.infra.db.metadb.repositories.user_repository.get_password_hash",
|
||||
lambda password: f"hashed::{password}",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
repo.update_user(
|
||||
1,
|
||||
UserUpdate(
|
||||
email="new@example.com",
|
||||
password="new-secret",
|
||||
role=UserRole.ADMIN,
|
||||
is_active=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
query, params = cursor.executed[0]
|
||||
assert "email = %(email)s" in query
|
||||
assert "hashed_password = %(hashed_password)s" in query
|
||||
assert "role = %(role)s" in query
|
||||
assert "is_active = %(is_active)s" in query
|
||||
assert params["hashed_password"] == "hashed::new-secret"
|
||||
assert params["role"] == "ADMIN"
|
||||
assert params["is_active"] is False
|
||||
|
||||
|
||||
def test_delete_user_returns_false_when_execute_raises():
|
||||
cursor = FakeCursor()
|
||||
cursor.execute = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(repo.delete_user(1))
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_user_exists_short_circuits_without_filters():
|
||||
cursor = FakeCursor()
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(repo.user_exists())
|
||||
|
||||
assert result is False
|
||||
assert cursor.executed == []
|
||||
|
||||
|
||||
def test_user_exists_checks_username_or_email():
|
||||
cursor = FakeCursor(fetchone_results=[{"exists": True}])
|
||||
repo = UserRepository(FakeDB(cursor))
|
||||
|
||||
result = asyncio.run(
|
||||
repo.user_exists(username="tester", email="tester@example.com")
|
||||
)
|
||||
|
||||
assert result is True
|
||||
query, params = cursor.executed[0]
|
||||
assert "username = %(username)s OR email = %(email)s" in query
|
||||
assert params == {"username": "tester", "email": "tester@example.com"}
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
# `app/api/v1/endpoints/` 失效 API 排查与修正
|
||||
|
||||
排查范围:`app/api/v1/endpoints/`
|
||||
|
||||
结论:本次共确认 5 个问题接口,处理结果如下:
|
||||
|
||||
- **已删除 4 个未实现坏接口**
|
||||
- **已修正 1 个签名失配接口**
|
||||
|
||||
> 路由统一前缀来自 `app/main.py:71`,以下完整路径均以 `/api/v1` 开头。
|
||||
|
||||
## 处理结果
|
||||
|
||||
| Method | API | 原问题 | 处理结果 |
|
||||
| --- | --- | --- | --- |
|
||||
| GET | `/api/v1/calculateregion/` | 调用时 `NameError`,底层无 `calculate_region` 实现 | **已删除** |
|
||||
| GET | `/api/v1/getallregions/` | 调用时 `NameError`,底层无 `get_all_regions` 实现 | **已删除** |
|
||||
| POST | `/api/v1/generateregion/` | 调用时 `NameError`,底层无 `generate_region` 实现 | **已删除** |
|
||||
| GET | `/api/v1/calculatedistrictmeteringarea/` | 调用时 `NameError`,仍指向已废弃旧 DMA 入口 | **已删除** |
|
||||
| GET | `/api/v1/calculateservicearea/` | endpoint 传 `time_index`,实现只接受 `name` | **已修正**,现返回全部时间步结果 |
|
||||
|
||||
## 删除原因
|
||||
|
||||
### 1. region 相关 3 个接口
|
||||
|
||||
以下能力在当前 `wndb` / `tjnetwork` 中均不存在:
|
||||
|
||||
- `calculate_region`
|
||||
- `get_all_regions`
|
||||
- `generate_region`
|
||||
|
||||
`app/native/wndb/__init__.py` 当前只提供 region CRUD 和 util 能力,不提供 region 计算或批量查询能力。因此这 3 个接口继续保留只会在运行时失败。
|
||||
|
||||
### 2. DMA 旧入口
|
||||
|
||||
旧接口 `calculate_district_metering_area(...)` 已不存在,当前只保留 3 个明确变体:
|
||||
|
||||
- `/api/v1/calculatedistrictmeteringareafornodes/`
|
||||
- `/api/v1/calculatedistrictmeteringareaforregion/`
|
||||
- `/api/v1/calculatedistrictmeteringareafornetwork/`
|
||||
|
||||
因此旧入口 `/api/v1/calculatedistrictmeteringarea/` 已删除,避免前端继续误用历史接口。
|
||||
|
||||
## 修正内容
|
||||
|
||||
### `GET /api/v1/calculateservicearea/`
|
||||
|
||||
原接口问题:
|
||||
|
||||
- endpoint 定义保留 `time_index`
|
||||
- 实际实现 `calculate_service_area(name)` 只接收 `network/name`
|
||||
- 调用时会触发参数数量不匹配
|
||||
|
||||
本次修正后:
|
||||
|
||||
- 移除 `time_index` 查询参数
|
||||
- 返回类型改为 `list[dict[str, list[str]]]`
|
||||
- 接口语义改为:**返回全部时间步的服务区计算结果**
|
||||
|
||||
## 当前可用替代接口
|
||||
|
||||
DMA 计算请使用:
|
||||
|
||||
- `/api/v1/calculatedistrictmeteringareafornodes/`
|
||||
- `/api/v1/calculatedistrictmeteringareaforregion/`
|
||||
- `/api/v1/calculatedistrictmeteringareafornetwork/`
|
||||
|
||||
服务区计算请使用:
|
||||
|
||||
- `/api/v1/calculateservicearea/`
|
||||
现在返回全部时间步结果,不再接收 `time_index`
|
||||
|
||||
## 变更文件
|
||||
|
||||
- `app/api/v1/endpoints/network/regions.py`
|
||||
- `失效API排查.md`
|
||||
Reference in New Issue
Block a user