修正单元测试失败代码

This commit is contained in:
2026-05-25 17:51:45 +08:00
parent 2317f4d527
commit 88be97ddeb
6 changed files with 171 additions and 160 deletions
+129 -117
View File
@@ -1,97 +1,96 @@
import os
import ctypes
from .project import have_project
from .inp_out import dump_inp
def calculate_service_area(name: str) -> list[dict[str, list[str]]]:
if not have_project(name):
raise Exception(f'Not found project [{name}]')
dir = os.path.abspath(os.getcwd())
inp_str = os.path.join(os.path.join(dir, 'db_inp'), name + '.db.inp')
dump_inp(name, inp_str, '2')
toolkit = ctypes.CDLL(os.path.join(os.path.join(dir, 'api'), 'toolkit.dll'))
inp = ctypes.c_char_p(inp_str.encode())
handle = ctypes.c_ulonglong()
toolkit.TK_ServiceArea_Start(inp, ctypes.byref(handle))
c_nodeCount = ctypes.c_size_t()
toolkit.TK_ServiceArea_GetNodeCount(handle, ctypes.byref(c_nodeCount))
nodeCount = c_nodeCount.value
nodeIds: list[str] = []
for n in range(0, nodeCount):
id = ctypes.c_char_p()
toolkit.TK_ServiceArea_GetNodeId(handle, ctypes.c_size_t(n), ctypes.byref(id))
nodeIds.append(id.value.decode())
c_timeCount = ctypes.c_size_t()
toolkit.TK_ServiceArea_GetTimeCount(handle, ctypes.byref(c_timeCount))
timeCount = c_timeCount.value
results: list[dict[str, list[str]]] = []
for t in range(0, timeCount):
c_sourceCount = ctypes.c_size_t()
toolkit.TK_ServiceArea_GetSourceCount(handle, ctypes.c_size_t(t), ctypes.byref(c_sourceCount))
sourceCount = c_sourceCount.value
sources = ctypes.POINTER(ctypes.c_size_t)()
toolkit.TK_ServiceArea_GetSources(handle, ctypes.c_size_t(t), ctypes.byref(sources))
result: dict[str, list[str]] = {}
for s in range(0, sourceCount):
result[nodeIds[sources[s]]] = []
for n in range(0, nodeCount):
concentration = ctypes.POINTER(ctypes.c_double)()
toolkit.TK_ServiceArea_GetConcentration(handle, ctypes.c_size_t(t), ctypes.c_size_t(n), ctypes.byref(concentration))
maxS = sources[0]
maxC = concentration[0]
for s in range(1, sourceCount):
if concentration[s] > maxC:
maxS = sources[s]
maxC = concentration[s]
result[nodeIds[maxS]].append(nodeIds[n])
results.append(result)
toolkit.TK_ServiceArea_End(handle)
return results
'''
import sys
import json
import platform
import subprocess
import uuid
from queue import Queue
from .database import *
from .s0_base import get_node_links, get_link_nodes
from typing import Any
sys.path.append('..')
from app.infra.epanet.epanet import run_project
from app.infra.epanet.epanet import Output
def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, list[str]]:
sources : dict[str, list[str]] = {}
for node_result in inp['node_results']:
from .inp_out import dump_inp
from .project import have_project
from .s0_base import get_link_nodes, get_node_links
from .s23_options_util import get_option_v3
def _update_section(lines: list[str], section: str, transform) -> list[str]:
result: list[str] = []
i = 0
while i < len(lines):
line = lines[i]
if line.strip() == f'[{section}]':
result.append(line)
i += 1
section_lines: list[str] = []
while i < len(lines) and not lines[i].startswith('['):
section_lines.append(lines[i])
i += 1
result.extend(transform(section_lines))
continue
result.append(line)
i += 1
return result
def _build_service_area_input(name: str, inp_path: str) -> None:
dump_inp(name, inp_path, '2')
with open(inp_path, encoding='utf-8') as file:
lines = file.read().splitlines()
unbalanced = get_option_v3(name).get('IF_UNBALANCED', '').strip()
if unbalanced != '':
lines = _update_section(
lines,
'OPTIONS',
lambda option_lines: [
f'UNBALANCED {unbalanced}' if line.startswith('UNBALANCED ') else line
for line in option_lines
],
)
with open(inp_path, mode='w', encoding='utf-8') as file:
file.write('\n'.join(lines) + '\n')
def _run_epanet_output(inp_path: str, rpt_path: str, out_path: str) -> dict[str, Any]:
epanet_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'infra', 'epanet'))
if platform.system() == 'Windows':
exe = os.path.join(epanet_dir, 'windows', 'runepanet.exe')
else:
exe = os.path.join(epanet_dir, 'linux', 'runepanet')
if not os.access(exe, os.X_OK):
os.chmod(exe, 0o755)
env = os.environ.copy()
if platform.system() == 'Linux':
lib_dir = os.path.dirname(exe)
env['LD_LIBRARY_PATH'] = f"{lib_dir}:{env.get('LD_LIBRARY_PATH', '')}"
process = subprocess.run([exe, inp_path, rpt_path, out_path], env=env, capture_output=True, text=True)
if process.returncode != 0:
raise RuntimeError(
f'EPANET failed for [{inp_path}] with code {process.returncode}: '
f'stdout={process.stdout} stderr={process.stderr}'
)
return Output(out_path).dump()
def _calculate_service_area(name: str, output: dict[str, Any], time_index: int) -> dict[str, list[str]]:
sources: dict[str, list[str]] = {}
for node_result in output['node_results']:
result = node_result['result'][time_index]
if result['demand'] < 0:
sources[node_result['node']] = []
link_flows: dict[str, float] = {}
for link_result in inp['link_results']:
for link_result in output['link_results']:
result = link_result['result'][time_index]
link_flows[link_result['link']] = float(result['flow'])
# build source to nodes map
for source in sources:
queue = Queue()
queue: Queue[str] = Queue()
queue.put(source)
while not queue.empty():
@@ -107,9 +106,6 @@ def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, li
elif node2 == cursor and link_flows[link] < 0:
queue.put(node1)
#return sources
# calculation concentration
concentration_map: dict[str, dict[str, float]] = {}
node_wip: list[str] = []
for source, nodes in sources.items():
@@ -120,17 +116,15 @@ def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, li
if node not in node_wip:
node_wip.append(node)
# if only one source, done
for node, concentrations in concentration_map.items():
if len(concentrations) == 1:
node_wip.remove(node)
for key in concentrations.keys():
concentration_map[node][key] = 1.0
for source in concentrations.keys():
concentration_map[node][source] = 1.0
node_upstream : dict[str, list[tuple[str, str]]] = {}
node_upstream: dict[str, list[tuple[str, str]]] = {}
for node in node_wip:
if node not in node_upstream:
node_upstream[node] = []
node_upstream[node] = []
links = get_node_links(name, node)
for link in links:
@@ -141,7 +135,7 @@ def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, li
node_upstream[node].append((link, node2))
while len(node_wip) != 0:
done = []
done: list[str] = []
for node in node_wip:
up_link_nodes = node_upstream[node]
ready = True
@@ -149,33 +143,38 @@ def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, li
if link_node[1] in node_wip:
ready = False
break
if ready:
for link_node in up_link_nodes:
if link_node[1] not in concentration_map.keys():
continue
for source, concentration in concentration_map[link_node[1]].items():
concentration_map[node][source] += concentration * abs(link_flows[link_node[0]])
if not ready:
continue
# normalize
sum = 0.0
for source, concentration in concentration_map[node].items():
sum += concentration
for source in concentration_map[node].keys():
concentration_map[node][source] /= sum
for link, upstream_node in up_link_nodes:
if upstream_node not in concentration_map:
continue
for source, concentration in concentration_map[upstream_node].items():
concentration_map[node][source] += concentration * abs(link_flows[link])
done.append(node)
total_concentration = sum(concentration_map[node].values())
if total_concentration == 0:
raise RuntimeError(f'Failed to normalize service area concentration for node [{node}] at time [{time_index}]')
for source in concentration_map[node].keys():
concentration_map[node][source] /= total_concentration
done.append(node)
if len(done) == 0:
raise RuntimeError(f'Failed to resolve service area graph for time [{time_index}]')
for node in done:
node_wip.remove(node)
source_to_main_node: dict[str, list[str]] = {}
for node, value in concentration_map.items():
for node, concentrations in concentration_map.items():
max_source = ''
max_concentration = 0.0
for s, c in value.items():
if c > max_concentration:
max_concentration = c
max_source = s
for source, concentration in concentrations.items():
if concentration > max_concentration:
max_concentration = concentration
max_source = source
if max_source not in source_to_main_node:
source_to_main_node[max_source] = []
source_to_main_node[max_source].append(node)
@@ -184,15 +183,28 @@ def _calculate_service_area(name: str, inp, time_index: int = 0) -> dict[str, li
def calculate_service_area(name: str) -> list[dict[str, list[str]]]:
inp = json.loads(run_project(name, True))
if not have_project(name):
raise Exception(f'Not found project [{name}]')
result: list[dict[str, list[str]]] = []
root = os.path.abspath(os.getcwd())
token = f'{os.getpid()}_{uuid.uuid4().hex}'
inp_path = os.path.join(root, 'db_inp', f'{name}.service_area.{token}.inp')
rpt_path = os.path.join(root, 'temp', f'{name}.service_area.{token}.rpt')
out_path = os.path.join(root, 'temp', f'{name}.service_area.{token}.opt')
time_count = len(inp['node_results'][0]['result'])
os.makedirs(os.path.dirname(inp_path), exist_ok=True)
os.makedirs(os.path.dirname(rpt_path), exist_ok=True)
for i in range(time_count):
sas = _calculate_service_area(name, inp, i)
result.append(sas)
try:
_build_service_area_input(name, inp_path)
output = _run_epanet_output(inp_path, rpt_path, out_path)
return result
'''
results: list[dict[str, list[str]]] = []
time_count = len(output['node_results'][0]['result'])
for time_index in range(time_count):
results.append(_calculate_service_area(name, output, time_index))
return results
finally:
for path in (inp_path, rpt_path, out_path):
if os.path.exists(path):
os.remove(path)