diff --git a/api/__init__.py b/api/__init__.py index 0854a57..9ea241c 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -4,7 +4,7 @@ from .project import copy_project from .operation import execute_undo as undo from .operation import execute_redo as redo -from .operation import take_snapshot, pick_snapshot +from .operation import have_snapshot, take_snapshot, pick_snapshot from .operation import have_transaction, start_transaction, commit_transaction, rollback_transaction from .s0_base import is_node, is_junction, is_reservoir, is_tank diff --git a/api/operation.py b/api/operation.py index b2d8ccd..969362f 100644 --- a/api/operation.py +++ b/api/operation.py @@ -119,6 +119,11 @@ def execute_redo(name: str) -> None: # snapshot is persistent # since redo always remember the recently undo path +def have_snapshot(name: str, tag: str) -> bool: + with conn[name].cursor(row_factory=dict_row) as cur: + cur.execute(f"select id from snapshot_operation where tag = '{tag}'") + return cur.rowcount > 0 + def take_snapshot(name: str, tag: str) -> None: if tag == None or tag == '': print('Non empty tag is expected!') @@ -127,7 +132,6 @@ def take_snapshot(name: str, tag: str) -> None: curr = _get_current_operation(name) with conn[name].cursor() as cur: - parent = _get_current_operation(name) cur.execute(f"insert into snapshot_operation (id, tag) values ({curr}, '{tag}')") def pick_snapshot(name: str, tag: str) -> None: @@ -135,14 +139,15 @@ def pick_snapshot(name: str, tag: str) -> None: print('Non empty tag is expected!') return + if not have_snapshot(name, tag): + print('No such snapshot!') + return + curr = _get_current_operation(name) curr_parents = _get_parents(name, curr) with conn[name].cursor(row_factory=dict_row) as cur: cur.execute(f"select id from snapshot_operation where tag = '{tag}'") - if cur.rowcount < 1: - print('No such snapshot!') - return target = int(cur.fetchone()['id']) if target in curr_parents: # target -> curr for i in range(curr_parents.index(target)): diff --git a/tjnetwork_new.py b/tjnetwork_new.py index 52cb0e5..005292a 100644 --- a/tjnetwork_new.py +++ b/tjnetwork_new.py @@ -49,6 +49,9 @@ def undo(name: str) -> None: def redo(name: str) -> None: return api.redo(name) +def have_snapshot(name: str, tag: str) -> None: + return api.have_snapshot(name, tag) + def take_snapshot(name: str, tag: str) -> None: return api.take_snapshot(name, tag)