Add link in base

This commit is contained in:
wqy
2022-09-17 09:56:54 +08:00
parent 435411356f
commit 4761a0d60b

View File

@@ -3,6 +3,7 @@ from .connection import g_conn_dict as conn
from .operation import *
from .change_set import ChangeSet
_NODE = "_node"
_LINK = "_link"
_CURVE = "_curve"
@@ -15,47 +16,59 @@ PIPE = "pipe"
PUMP = "pump"
VALVE = "valve"
def _get_from(name: str, id: str, base_type: str) -> Row | None:
with conn[name].cursor(row_factory=dict_row) as cur:
cur.execute(f"select * from {base_type} where id = '{id}'")
return cur.fetchone()
def is_node(name: str, id: str) -> bool:
return _get_from(name, id, _NODE) != None
def is_junction(name: str, id: str) -> bool:
row = _get_from(name, id, _NODE)
return row != None and row['type'] == JUNCTION
def is_reservoir(name: str, id: str) -> bool:
row = _get_from(name, id, _NODE)
return row != None and row['type'] == RESERVOIR
def is_tank(name: str, id: str) -> bool:
row = _get_from(name, id, _NODE)
return row != None and row['type'] == TANK
def is_link(name: str, id: str) -> bool:
return _get_from(name, id, _LINK) != {}
def is_pipe(name: str, id: str) -> bool:
row = _get_from(name, id, _LINK)
return row != None and row['type'] == PIPE
def is_pump(name: str, id: str) -> bool:
row = _get_from(name, id, _LINK)
return row != None and row['type'] == PUMP
def is_valve(name: str, id: str) -> bool:
row = _get_from(name, id, _LINK)
return row != None and row['type'] == VALVE
def is_curve(name: str, id: str) -> bool:
return _get_from(name, id, _CURVE) != None
def is_pattern(name: str, id: str) -> bool:
return _get_from(name, id, _PATTERN) != None
def _get_all(name: str, base_type: str) -> list[str]:
ids : list[str] = []
with conn[name].cursor(row_factory=dict_row) as cur:
@@ -64,18 +77,23 @@ def _get_all(name: str, base_type: str) -> list[str]:
ids.append(record['id'])
return ids
def get_nodes(name: str) -> list[str]:
return _get_all(name, _NODE)
def get_links(name: str) -> list[str]:
return _get_all(name, _LINK)
def get_curves(name: str) -> list[str]:
return _get_all(name, _CURVE)
def get_patterns(name: str) -> list[str]:
return _get_all(name, _PATTERN)
def add_node(name: str, node_type: str, id: str, x: float, y: float, table_sql: str, table_undo_sql: str) -> ChangeSet:
if is_node(name, id):
return
@@ -96,6 +114,7 @@ def add_node(name: str, node_type: str, id: str, x: float, y: float, table_sql:
change.add(node_type, id)
return change
def delete_node(name: str, node_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet:
if not is_node(name, id):
return
@@ -122,3 +141,41 @@ def delete_node(name: str, node_type: str, id: str, table_sql: str, table_undo_s
change = ChangeSet()
change.delete(node_type, id)
return change
def add_link(name: str, link_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet:
if is_link(name, id):
return
with conn[name].cursor() as cur:
sql = f"insert into _link (id, type) values ('{id}', '{link_type}'); "
sql += table_sql
cur.execute(sql)
redo = sql.replace("'", '"')
undo = table_undo_sql
undo += f' delete from _link where id = "{id}";'
add_operation(name, redo, undo)
change = ChangeSet()
change.add(link_type, id)
return change
def delete_link(name: str, link_type: str, id: str, table_sql: str, table_undo_sql: str) -> ChangeSet:
if not is_node(name, id):
return
with conn[name].cursor(row_factory=dict_row) as cur:
sql = table_sql
sql += f" delete from _link where id = '{id}';"
cur.execute(sql)
redo = sql.replace("'", '"')
undo = f'insert into _link (id, type) values ("{id}", "{link_type}"); '
undo += table_undo_sql
add_operation(name, redo, undo)
change = ChangeSet()
change.delete(link_type, id)
return change