From 555cab36276d6319f5c8891e182c20462a4af81b Mon Sep 17 00:00:00 2001 From: wqy Date: Sat, 3 Sep 2022 09:30:57 +0800 Subject: [PATCH] Support transaction to commit or rollback, this will modify operation tree --- api/__init__.py | 1 + api/operation.py | 73 ++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/api/__init__.py b/api/__init__.py index 44a98b8..0854a57 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -5,6 +5,7 @@ from .project import copy_project from .operation import execute_undo as undo from .operation import execute_redo as redo from .operation import take_snapshot, pick_snapshot +from .operation import have_transaction, start_transaction, commit_transaction, rollback_transaction from .s0_base import is_node, is_junction, is_reservoir, is_tank from .s0_base import is_link, is_pipe, is_pump, is_valve diff --git a/api/operation.py b/api/operation.py index 53d1f45..b2d8ccd 100644 --- a/api/operation.py +++ b/api/operation.py @@ -1,6 +1,24 @@ -from psycopg.rows import dict_row +from psycopg.rows import dict_row, Row from .connection import g_conn_dict as conn +def _get_current_transaction(name: str) -> Row | None: + with conn[name].cursor(row_factory=dict_row) as cur: + cur.execute(f"select id from transaction_operation") + return cur.fetchone() + +def _get_current_transaction_id(name: str) -> int: + row = _get_current_transaction(name) + return int(row['id']) + +def _remove_transaction(name: str) -> None: + with conn[name].cursor() as cur: + cur.execute(f"delete from transaction_operation") + +def _remove_operation(name: str, id: int) -> None: + with conn[name].cursor(row_factory=dict_row) as cur: + cur.execute(f"delete from operation where id = {id}") + return int(cur.fetchone()['id']) + def _get_parents(name: str, id: int) -> list[int]: ids = [id] with conn[name].cursor(row_factory=dict_row) as cur: @@ -42,7 +60,7 @@ def _query_redo(name: str, id: str) -> dict[str, str]: cur.execute(f"select redo from operation where id = {id}") return cur.fetchone()['redo'] -def _set_redo_child(name: str, id: str, child: str) -> None: +def _set_redo_child(name: str, id: int, child: int | str) -> None: with conn[name].cursor() as cur: cur.execute(f"update operation set redo_child = {child} where id = {id}") @@ -56,8 +74,19 @@ def add_operation(name: str, redo: str, undo: str) -> None: old = _get_current_operation(name) _update_current_operation(name, old, curr) -def execute_undo(name: str) -> None: +def execute_undo(name: str, discard: bool = False) -> None: curr = _get_current_operation(name) + + # transaction control + tran = _get_current_transaction(name) + if int(tran['id']) >= 0: + if bool(tran['strict']): # strict mode disallow undo + print("Do not allow to undo in strict transaction mode!") + return + elif tran <= curr: # normal mode disallow undo start point + print("Do not allow to undo transaction start point!") + return + row = _query_undo(name, curr) undo = row['undo'] if undo == '': @@ -65,11 +94,14 @@ def execute_undo(name: str) -> None: return parent = int(row['parent']) - _set_redo_child(name, parent, curr) + _set_redo_child(name, parent, 'NULL' if discard else curr) _execute(name, undo) _update_current_operation(name, curr, parent) + if discard: + _remove_operation(name, curr) + def execute_redo(name: str) -> None: curr = _get_current_operation(name) redoChild = _query_redo_child(name, curr) @@ -141,11 +173,36 @@ def pick_snapshot(name: str, tag: str) -> None: # can not undo a committed transaction or redo a rollback transaction. # it may remove snapshot tag if snapshot in a rollback transaction -def start_transaction(name: str) -> None: - pass +def have_transaction(name: str) -> bool: + return _get_current_transaction_id(name) >= 0 + +def start_transaction(name: str, strict: bool = False) -> None: + if have_transaction(name): + print("Only support single transaction now, please commit/rollback current transaction!") + return + + curr = _get_current_operation(name) + + with conn[name].cursor() as cur: + cur.execute(f"insert into transaction_operation (id, strict) values ({curr}, {strict});") def commit_transaction(name: str) -> None: - pass + if not have_transaction(name): + print("No active transaction!") + return + + _remove_transaction(name) def rollback_transaction(name: str) -> None: - pass + tran = _get_current_transaction_id(name) + if tran >= 0: + print("No active transaction!") + return + + curr = _get_current_operation(name) + curr_parents = _get_parents(name, curr) + + for i in range(curr_parents.index(tran)): + execute_undo(name, True) + + _remove_transaction(name)