From a3f88c61fabb773417c45a9f77f4291dcec9aa95 Mon Sep 17 00:00:00 2001 From: wqy Date: Fri, 28 Jul 2023 19:36:38 +0800 Subject: [PATCH] Switch metis to pymetis --- api/s33_dma_cal.py | 56 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/api/s33_dma_cal.py b/api/s33_dma_cal.py index 5f30967..e2d7ae9 100644 --- a/api/s33_dma_cal.py +++ b/api/s33_dma_cal.py @@ -1,17 +1,71 @@ import ctypes import os +import numpy as np +import pymetis from .database import * from .s0_base import get_nodes from .s32_region_util import get_nodes_in_region from .s32_region_util import Topology -from .s32_region import get_region PARTITION_TYPE_RB = 0 PARTITION_TYPE_KWAY = 1 +''' +adjacency_list = [np.array([4, 2, 1]), + np.array([0, 2, 3]), + np.array([4, 3, 1, 0]), + np.array([1, 2, 5, 6]), + np.array([0, 2, 5]), + np.array([4, 3, 6]), + np.array([5, 3])] +n_cuts, membership = pymetis.part_graph(2, adjacency=adjacency_list) +# n_cuts = 3 +# membership = [1, 1, 1, 0, 1, 0, 0] + +nodes_part_0 = np.argwhere(np.array(membership) == 0).ravel() # [3, 5, 6] +nodes_part_1 = np.argwhere(np.array(membership) == 1).ravel() # [0, 1, 2, 4] + +print(nodes_part_0) +print(nodes_part_1) +''' + def calculate_district_metering_area_for_nodes(name: str, nodes: list[str], part_count: int = 1, part_type: int = PARTITION_TYPE_RB) -> list[list[str]]: + topology = Topology(name, nodes) + t_nodes = topology.nodes() + t_links = topology.links() + t_node_list = topology.node_list() + + adjacency_list = [] + + for node in t_node_list: + links: list[str] = t_nodes[node]['links'] + a_nodes: list[int] = [] + for link in links: + if t_links[link]['node1'] == node: + i = t_node_list.index(t_links[link]['node2']) + a_nodes.append(i) + elif t_links[link]['node2'] == node: + i = t_node_list.index(t_links[link]['node1']) + a_nodes.append(i) + adjacency_list.append(np.array(a_nodes)) + + recursive = part_type == PARTITION_TYPE_RB + n_cuts, membership = pymetis.part_graph(nparts=part_count, adjacency=adjacency_list, recursive=recursive, contiguous=True) + + result: list[list[str]] = [] + for i in range(0, part_count): + indices: list[int] = list(np.argwhere(np.array(membership) == i).ravel()) + index_strs: list[str] = [] + for index in indices: + index_strs.append(t_node_list[index]) + result.append(index_strs) + + return result + + +def _calculate_district_metering_area_for_nodes(name: str, nodes: list[str], part_count: int = 1, part_type: int = PARTITION_TYPE_RB) -> list[list[str]]: if part_type != PARTITION_TYPE_RB and part_type != PARTITION_TYPE_KWAY: return [] if part_count <= 0: