diff --git a/api/operation.py b/api/operation.py index 57c2fd9..cedb82c 100644 --- a/api/operation.py +++ b/api/operation.py @@ -3,7 +3,7 @@ from .connection import g_conn_dict as conn def _get_current_transaction(name: str) -> Row | None: with conn[name].cursor(row_factory=dict_row) as cur: - cur.execute(f"select id from transaction_operation") + cur.execute(f"select * from transaction_operation") return cur.fetchone() def _get_current_transaction_id(name: str) -> int: @@ -20,7 +20,6 @@ def _remove_operation(name: str, id: int) -> None: cur.execute(f"delete from transaction_operation where id = {id}") # this should not happen cur.execute(f"delete from snapshot_operation where id = {id}") # this may happen cur.execute(f"delete from operation where id = {id}") - return int(cur.fetchone()['id']) def _get_parents(name: str, id: int) -> list[int]: ids = [id] @@ -81,14 +80,15 @@ def execute_undo(name: str, discard: bool = False) -> None: curr = _get_current_operation(name) # transaction control - tran = _get_current_transaction(name) - if int(tran['id']) >= 0: - if bool(tran['strict']): # strict mode disallow undo - print("Do not allow to undo in strict transaction mode!") - return - elif tran <= curr: # normal mode disallow undo start point, and there is foreign key constraint - print("Do not allow to undo transaction start point!") - return + if have_transaction(name): + tran = _get_current_transaction(name) + if tran != None and int(tran['id']) >= 0: + if bool(tran['strict']): # strict mode disallow undo + print("Do not allow to undo in strict transaction mode!") + return + elif int(tran['id']) >= curr: # normal mode disallow undo start point, and there is foreign key constraint + print("Do not allow to undo transaction start point!") + return row = _query_undo(name, curr) undo = row['undo'] @@ -183,7 +183,9 @@ def pick_snapshot(name: str, tag: str) -> None: # it may remove snapshot tag if snapshot in a rollback transaction def have_transaction(name: str) -> bool: - return _get_current_transaction_id(name) >= 0 + with conn[name].cursor(row_factory=dict_row) as cur: + cur.execute(f"select * from transaction_operation") + return cur.rowcount > 0 def start_transaction(name: str, strict: bool = False) -> None: if have_transaction(name): @@ -203,11 +205,12 @@ def commit_transaction(name: str) -> None: _remove_transaction(name) def abort_transaction(name: str) -> None: - tran = _get_current_transaction_id(name) - if tran >= 0: + if not have_transaction(name): print("No active transaction!") return + tran = _get_current_transaction_id(name) + curr = _get_current_operation(name) curr_parents = _get_parents(name, curr) diff --git a/new_demo.py b/new_demo.py index 5974d8f..d14cc57 100644 --- a/new_demo.py +++ b/new_demo.py @@ -1,5 +1,6 @@ from tjnetwork_new import * + def demo_snapshot(): p = "demo_snapshot" @@ -13,8 +14,8 @@ def demo_snapshot(): open_project(p) add_junction(p, 'j-1', 10.0, 20.0, 30.0) - add_junction(p, 'j-2', 10.0, 20.0, 30.0) - add_junction(p, 'j-3', 10.0, 20.0, 30.0) + add_junction(p, 'j-2', 10.0, 20.0, 30.0) + add_junction(p, 'j-3', 10.0, 20.0, 30.0) add_junction(p, 'j-4', 10.0, 20.0, 30.0) take_snapshot(p, "1-2-3-4") @@ -35,6 +36,41 @@ def demo_snapshot(): # delete_project(p) +def demo_transaction(): + p = "demo_transaction" + + if is_project_open(p): + close_project(p) + + if have_project(p): + delete_project(p) + + create_project(p) + open_project(p) + + add_junction(p, 'j-1', 10.0, 20.0, 30.0) + take_snapshot(p, "1") + add_junction(p, 'j-2', 10.0, 20.0, 30.0) + take_snapshot(p, "2") + + start_transaction(p) + + add_junction(p, 'j-3', 10.0, 20.0, 30.0) + take_snapshot(p, "3") + add_junction(p, 'j-4', 10.0, 20.0, 30.0) + take_snapshot(p, "4") + + abort_transaction(p) + + print(have_snapshot(p, "1")) + print(have_snapshot(p, "2")) + print(have_snapshot(p, "3")) + print(have_snapshot(p, "4")) + + close_project(p) + # delete_project(p) + + def demo_1_title(): p = "demo_1_title" @@ -121,7 +157,8 @@ def demo_2_junctions(): if __name__ == "__main__": - demo_snapshot() + # demo_snapshot() + demo_transaction() # demo_1_title() # demo_2_junctions() pass