From 4761a0d60bb10699d2702880082cb99bb30a3dc0 Mon Sep 17 00:00:00 2001 From: wqy Date: Sat, 17 Sep 2022 09:56:54 +0800 Subject: [PATCH] Add link in base --- api/s0_base.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/api/s0_base.py b/api/s0_base.py index 37d9488..e4a5032 100644 --- a/api/s0_base.py +++ b/api/s0_base.py @@ -3,6 +3,7 @@ from .connection import g_conn_dict as conn from .operation import * from .change_set import ChangeSet + _NODE = "_node" _LINK = "_link" _CURVE = "_curve" @@ -15,47 +16,59 @@ PIPE = "pipe" PUMP = "pump" VALVE = "valve" + def _get_from(name: str, id: str, base_type: str) -> Row | None: with conn[name].cursor(row_factory=dict_row) as cur: cur.execute(f"select * from {base_type} where id = '{id}'") return cur.fetchone() + def is_node(name: str, id: str) -> bool: return _get_from(name, id, _NODE) != None + def is_junction(name: str, id: str) -> bool: row = _get_from(name, id, _NODE) return row != None and row['type'] == JUNCTION + def is_reservoir(name: str, id: str) -> bool: row = _get_from(name, id, _NODE) return row != None and row['type'] == RESERVOIR + def is_tank(name: str, id: str) -> bool: row = _get_from(name, id, _NODE) return row != None and row['type'] == TANK + def is_link(name: str, id: str) -> bool: return _get_from(name, id, _LINK) != {} + def is_pipe(name: str, id: str) -> bool: row = _get_from(name, id, _LINK) return row != None and row['type'] == PIPE + def is_pump(name: str, id: str) -> bool: row = _get_from(name, id, _LINK) return row != None and row['type'] == PUMP + def is_valve(name: str, id: str) -> bool: row = _get_from(name, id, _LINK) return row != None and row['type'] == VALVE + def is_curve(name: str, id: str) -> bool: return _get_from(name, id, _CURVE) != None + def is_pattern(name: str, id: str) -> bool: return _get_from(name, id, _PATTERN) != None + def _get_all(name: str, base_type: str) -> list[str]: ids : list[str] = [] with conn[name].cursor(row_factory=dict_row) as cur: @@ -64,18 +77,23 @@ def _get_all(name: str, base_type: str) -> list[str]: ids.append(record['id']) return ids + def get_nodes(name: str) -> list[str]: return _get_all(name, _NODE) + def get_links(name: str) -> list[str]: return _get_all(name, _LINK) + def get_curves(name: str) -> list[str]: return _get_all(name, _CURVE) + def get_patterns(name: str) -> list[str]: return _get_all(name, _PATTERN) + def add_node(name: str, node_type: str, id: str, x: float, y: float, table_sql: str, table_undo_sql: str) -> ChangeSet: if is_node(name, id): return @@ -96,6 +114,7 @@ def add_node(name: str, node_type: str, id: str, x: float, y: float, table_sql: change.add(node_type, id) return change + def delete_node(name: str, node_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet: if not is_node(name, id): return @@ -122,3 +141,41 @@ def delete_node(name: str, node_type: str, id: str, table_sql: str, table_undo_s change = ChangeSet() change.delete(node_type, id) return change + + +def add_link(name: str, link_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet: + if is_link(name, id): + return + + with conn[name].cursor() as cur: + sql = f"insert into _link (id, type) values ('{id}', '{link_type}'); " + sql += table_sql + cur.execute(sql) + + redo = sql.replace("'", '"') + undo = table_undo_sql + undo += f' delete from _link where id = "{id}";' + add_operation(name, redo, undo) + + change = ChangeSet() + change.add(link_type, id) + return change + + +def delete_link(name: str, link_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet: + if not is_node(name, id): + return + + with conn[name].cursor(row_factory=dict_row) as cur: + sql = table_sql + sql += f" delete from _link where id = '{id}';" + cur.execute(sql) + + redo = sql.replace("'", '"') + undo = f'insert into _link (id, type) values ("{id}", "{link_type}"); ' + undo += table_undo_sql + add_operation(name, redo, undo) + + change = ChangeSet() + change.delete(link_type, id) + return change