diff --git a/api/__init__.py b/api/__init__.py index b3e9439..5d348c6 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -3,6 +3,7 @@ from .project import is_project_open, open_project, close_project from .operation import execute_undo as undo from .operation import execute_redo as redo +from .operation import take_snapshot, pick_snapshot from .s0_base import is_node, is_junction, is_reservoir, is_tank from .s0_base import is_link, is_pipe, is_pump, is_valve diff --git a/api/operation.py b/api/operation.py index f0bb301..92c8d67 100644 --- a/api/operation.py +++ b/api/operation.py @@ -74,3 +74,79 @@ def execute_redo(name: str) -> None: _execute(name, redo) _update_current_operation(name, curr, child) + +# snapshot support to check out between different version of database +# snapshot is persistent +# since redo always remember the recently undo path + +def take_snapshot(name: str, tag: str) -> None: + if tag == None or tag == '': + print('Non empty tag is expected!') + return + + curr = _get_current_operation(name) + + with conn[name].cursor() as cur: + parent = _get_current_operation(name) + cur.execute(f"insert into snapshot_operation (id, tag) values ({curr}, '{tag}')") + + +def _get_parents(name: str, id: int) -> list[int]: + ids = [id] + with conn[name].cursor(row_factory=dict_row) as cur: + while ids[-1] != 0: + cur.execute(f"select parent from operation where id = {ids[-1]}") + ids.append(int(cur.fetchone()['parent'])) + return ids + +def pick_snapshot(name: str, tag: str) -> None: + if tag == None or tag == '': + print('Non empty tag is expected!') + return + + curr = _get_current_operation(name) + curr_parents = _get_parents(name, curr) + + with conn[name].cursor(row_factory=dict_row) as cur: + cur.execute(f"select id from snapshot_operation where tag = '{tag}'") + if cur.rowcount < 1: + print('No such snapshot!') + return + target = int(cur.fetchone()['id']) + if target in curr_parents: # target -> curr + for i in range(curr_parents.index(target)): + execute_undo(name) + else: + target_parents = _get_parents(name, target) + if curr in target_parents: # curr -> target + for i in range(target_parents.index(curr)): + execute_redo(name) + else: + ancestor_index = -1 + while curr_parents[ancestor_index] == target_parents[ancestor_index]: + ancestor_index -= 1 + + # ancestor -> curr + ancestor = curr_parents[ancestor_index + 1] # ancestor_index + 1 is common parent + for i in range(curr_parents.index(ancestor)): + execute_undo(name) + # ancestor -> redo, need assign redo_child + while target_parents[ancestor_index] != target: + cur.execute(f"update operation set redo_child = '{target_parents[ancestor_index]}' where id = '{target_parents[ancestor_index + 1]}'") + execute_redo(name) + ancestor_index -= 1 + cur.execute(f"update operation set redo_child = '{target}' where id = '{target_parents[1]}'") + execute_redo(name) + +# transaction is volatile, commit/rollback will destroy transaction. +# can not undo a committed transaction or redo a rollback transaction. +# it may remove snapshot tag if snapshot in a rollback transaction + +def start_transaction(name: str) -> None: + pass + +def commit_transaction(name: str) -> None: + pass + +def rollback_transaction(name: str) -> None: + pass diff --git a/new_demo.py b/new_demo.py index 553c9d4..5974d8f 100644 --- a/new_demo.py +++ b/new_demo.py @@ -1,5 +1,40 @@ from tjnetwork_new import * +def demo_snapshot(): + p = "demo_snapshot" + + if is_project_open(p): + close_project(p) + + if have_project(p): + delete_project(p) + + create_project(p) + open_project(p) + + add_junction(p, 'j-1', 10.0, 20.0, 30.0) + add_junction(p, 'j-2', 10.0, 20.0, 30.0) + add_junction(p, 'j-3', 10.0, 20.0, 30.0) + add_junction(p, 'j-4', 10.0, 20.0, 30.0) + take_snapshot(p, "1-2-3-4") + + undo(p) + undo(p) + undo(p) + undo(p) + + add_junction(p, 'j-5', 10.0, 20.0, 30.0) + add_junction(p, 'j-6', 10.0, 20.0, 30.0) + add_junction(p, 'j-7', 10.0, 20.0, 30.0) + add_junction(p, 'j-8', 10.0, 20.0, 30.0) + take_snapshot(p, "5-6-7-8") + + pick_snapshot(p, "1-2-3-4") + + close_project(p) + # delete_project(p) + + def demo_1_title(): p = "demo_1_title" @@ -27,7 +62,8 @@ def demo_1_title(): print(get_title(p)) # close_project(p) - delete_project(p) + # delete_project(p) + def demo_2_junctions(): p = "demo_2_junctions" @@ -77,9 +113,15 @@ def demo_2_junctions(): undo(p) print(get_junction_coord(p, j)) # {'x': 10.0, 'y': 20.0} + redo(p) + print(get_junction_coord(p, j)) # {'x': 100.0, 'y': 200.0} + close_project(p) - delete_project(p) + # delete_project(p) + if __name__ == "__main__": - demo_1_title() - demo_2_junctions() + demo_snapshot() + # demo_1_title() + # demo_2_junctions() + pass diff --git a/tjnetwork_new.py b/tjnetwork_new.py index 73375fa..4e8632e 100644 --- a/tjnetwork_new.py +++ b/tjnetwork_new.py @@ -46,6 +46,12 @@ def undo(name: str) -> None: def redo(name: str) -> None: return api.redo(name) +def take_snapshot(name: str, tag: str) -> None: + return api.take_snapshot(name, tag) + +def pick_snapshot(name: str, tag: str) -> None: + return api.pick_snapshot(name, tag) + ############################################################ # type