diff --git a/api/__init__.py b/api/__init__.py index b8ac9e9..e31759f 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -11,7 +11,7 @@ from .operation import execute_undo, execute_redo from .operation import have_snapshot, take_snapshot, pick_snapshot from .operation import pick_operation, sync_with_server -from .command import execute_batch_commands +from .command import execute_batch_command, execute_batch_commands from .s0_base import JUNCTION, RESERVOIR, TANK, PIPE, PUMP, VALVE, PATTERN, CURVE from .s0_base import is_node, is_junction, is_reservoir, is_tank diff --git a/api/command.py b/api/command.py index 791e147..d74eab3 100644 --- a/api/command.py +++ b/api/command.py @@ -103,3 +103,118 @@ def execute_batch_commands(name: str, cs: ChangeSet) -> ChangeSet: pass return result + + +def cache_add_command(name: str, cs: ChangeSet) -> SqlChangeSet | None: + type = cs.operations[0]['type'] + + if type == JUNCTION: + return add_junction_cache(name, cs) + elif type == RESERVOIR: + return add_reservoir_cache(name, cs) + elif type == TANK: + return add_tank_cache(name, cs) + elif type == PIPE: + return add_pipe_cache(name, cs) + elif type == PUMP: + return add_pump_cache(name, cs) + elif type == VALVE: + return add_valve_cache(name, cs) + + return None + + +def cache_update_command(name: str, cs: ChangeSet) -> SqlChangeSet | None: + type = cs.operations[0]['type'] + + if type == 'title': + return set_title_cache(name, cs) + if type == JUNCTION: + return set_junction_cache(name, cs) + elif type == RESERVOIR: + return set_reservoir_cache(name, cs) + elif type == TANK: + return set_tank_cache(name, cs) + elif type == PIPE: + return set_pipe_cache(name, cs) + elif type == PUMP: + return set_pump_cache(name, cs) + elif type == VALVE: + return set_valve_cache(name, cs) + elif type == 'demand': + return set_demand_cache(name, cs) + elif type == 'status': + return set_status_cache(name, cs) + elif type == PATTERN: + return set_pattern_cache(name, cs) + elif type == CURVE: + return set_curve_cache(name, cs) + elif type == 'emitter': + return set_emitter_cache(name, cs) + elif type == 'time': + return set_time_cache(name, cs) + elif type == 'option': + return set_option_cache(name, cs) + + return None + + +def cache_delete_command(name: str, cs: ChangeSet) -> SqlChangeSet | None: + type = cs.operations[0]['type'] + + if type == JUNCTION: + return delete_junction_cache(name, cs) + elif type == RESERVOIR: + return delete_reservoir_cache(name, cs) + elif type == TANK: + return delete_tank_cache(name, cs) + elif type == PIPE: + return delete_pipe_cache(name, cs) + elif type == PUMP: + return delete_pump_cache(name, cs) + elif type == VALVE: + return delete_valve_cache(name, cs) + + return None + + +def execute_batch_command(name: str, cs: ChangeSet) -> ChangeSet: + redo_sql_s = [] + undo_sql_s = [] + redo_cs_s = [] + undo_cs_s = [] + + try: + for op in cs.operations: + operation = op['operation'] + + r = None + + if operation == API_ADD: + r = cache_add_command(name, ChangeSet(op)) + elif operation == API_UPDATE: + r = cache_update_command(name, ChangeSet(op)) + elif operation == API_DELETE: + r = cache_delete_command(name, ChangeSet(op)) + + if r == None: + return ChangeSet() + + redo_sql_s.append(r.redo_sql) + undo_sql_s.append(r.undo_sql) + redo_cs_s.append(r.redo_cs) + undo_cs_s.append(r.undo_cs) + except: + pass + + redo_sql = '\n'.join(redo_sql_s) + + undo_sql_s.reverse() + undo_sql = '\n'.join(undo_sql_s) + + undo_cs_s.reverse() + + try: + return execute_batch(name, redo_sql, undo_sql, redo_cs_s, undo_cs_s) + except: + return ChangeSet() diff --git a/api/operation.py b/api/operation.py index 75777e4..7fe031b 100644 --- a/api/operation.py +++ b/api/operation.py @@ -96,6 +96,26 @@ def execute_command(name: str, command: SqlChangeSet) -> ChangeSet: return ChangeSet(command.redo_cs) +def execute_batch(name: str, redo_sql: str, undo_sql: str, redo_cs_s: list[dict[str, Any]], undo_cs_s: list[dict[str, Any]]) -> ChangeSet: + write(name, redo_sql) + + parent = get_current_operation(name) + redo_sql = redo_sql.replace("'", "''") + undo_sql = undo_sql.replace("'", "''") + redo_cs_str = str(redo_cs_s).replace("'", "''") + undo_cs_str = str(undo_cs_s).replace("'", "''") + write(name, f"insert into operation (id, redo, undo, parent, redo_cs, undo_cs) values (default, '{redo_sql}', '{undo_sql}', {parent}, '{redo_cs_str}', '{undo_cs_str}')") + + current = read(name, 'select max(id) as id from operation')['id'] + write(name, f"update current_operation set id = {current}") + + cs = ChangeSet() + for r_cs in redo_cs_s: + cs.append(r_cs) + + return cs + + def execute_undo(name: str, discard: bool = False) -> ChangeSet: row = read(name, f'select * from operation where id = {get_current_operation(name)}') @@ -112,7 +132,15 @@ def execute_undo(name: str, discard: bool = False) -> ChangeSet: else: write(name, f"update operation set redo_child = {row['id']} where id = {row['parent']}") - return ChangeSet(eval(row['undo_cs'])) + e = eval(row['undo_cs']) + if isinstance(e, type({})): + return ChangeSet(e) + + cs = ChangeSet() + for _cs in e: + cs.append(_cs) + + return cs def execute_redo(name: str) -> ChangeSet: @@ -125,7 +153,15 @@ def execute_redo(name: str) -> ChangeSet: write(name, f"update current_operation set id = {row['id']} where id = {row['parent']}") - return ChangeSet(eval(row['redo_cs'])) + e = eval(row['redo_cs']) + if isinstance(e, type({})): + return ChangeSet(e) + + cs = ChangeSet() + for _cs in e: + cs.append(_cs) + + return cs def have_snapshot(name: str, tag: str) -> bool: @@ -195,9 +231,30 @@ def pick_snapshot(name: str, tag: str, discard: bool) -> ChangeSet: return pick_operation(name, target, discard) -def _get_change_set(name: str, operation: int, undo: bool) -> dict[str, Any]: +def _get_change_set(name: str, operation: int, undo: bool) -> ChangeSet: row = read(name, f'select * from operation where id = {operation}') - return eval(row['undo_cs']) if undo else eval(row['redo_cs']) + + cs = ChangeSet() + if undo: + e = eval(row['undo_cs']) + if isinstance(e, type({})): + return ChangeSet(e) + + cs = ChangeSet() + for _cs in e: + cs.append(_cs) + + return cs + else: + e = eval(row['redo_cs']) + if isinstance(e, type({})): + return ChangeSet(e) + + cs = ChangeSet() + for _cs in e: + cs.append(_cs) + + return cs def sync_with_server(name: str, operation: int) -> ChangeSet: @@ -212,13 +269,13 @@ def sync_with_server(name: str, operation: int) -> ChangeSet: if fr in to_parents: index = to_parents.index(fr) - 1 while index >= 0: - change.append(_get_change_set(name, to_parents[index], False)) #redo + change.merge(_get_change_set(name, to_parents[index], False)) #redo index -= 1 elif to in fr_parents: index = 0 while index <= fr_parents.index(to) - 1: - change.append(_get_change_set(name, fr_parents[index], True)) + change.merge(_get_change_set(name, fr_parents[index], True)) index += 1 else: @@ -230,12 +287,12 @@ def sync_with_server(name: str, operation: int) -> ChangeSet: index = 0 while index <= fr_parents.index(ancestor) - 1: - change.append(_get_change_set(name, fr_parents[index], True)) + change.merge(_get_change_set(name, fr_parents[index], True)) index += 1 index = to_parents.index(ancestor) - 1 while index >= 0: - change.append(_get_change_set(name, to_parents[index], False)) + change.merge(_get_change_set(name, to_parents[index], False)) index -= 1 return change.compress() diff --git a/test_tjnetwork.py b/test_tjnetwork.py index c5d4b1a..918f547 100644 --- a/test_tjnetwork.py +++ b/test_tjnetwork.py @@ -1834,5 +1834,44 @@ class TestApi: self.leave(p) + def test_batch_command(self): + p = 'test_batch_command' + self.enter(p) + + cs = ChangeSet() + cs.add({'type': JUNCTION, 'id': 'j1', 'x': 0.0, 'y': 10.0, 'elevation': 20.0}) + cs.add({'type': JUNCTION, 'id': 'j2', 'x': 0.0, 'y': 10.0, 'elevation': 20.0}) + cs.add({'type': JUNCTION, 'id': 'j2', 'x': 0.0, 'y': 10.0, 'elevation': 20.0}) # fail + + cs = execute_batch_command(p, cs) + assert len(cs.operations) == 0 + + assert get_current_operation(p) == 0 + + cs = ChangeSet() + cs.add({'type': JUNCTION, 'id': 'j1', 'x': 0.0, 'y': 10.0, 'elevation': 20.0}) + cs.add({'type': JUNCTION, 'id': 'j2', 'x': 0.0, 'y': 10.0, 'elevation': 20.0}) + + cs = execute_batch_command(p, cs) + + assert get_current_operation(p) == 1 + + cs = ChangeSet() + cs.delete({'type': JUNCTION, 'id': 'j1'}) + cs.delete({'type': JUNCTION, 'id': 'j2'}) + + cs = execute_batch_command(p, cs) + + assert get_current_operation(p) == 2 + + cs = execute_undo(p) + assert get_current_operation(p) == 1 + + cs = execute_undo(p) + assert get_current_operation(p) == 0 + + self.leave(p) + + if __name__ == '__main__': pytest.main() diff --git a/tjnetwork.py b/tjnetwork.py index 11fef7d..9f88c8e 100644 --- a/tjnetwork.py +++ b/tjnetwork.py @@ -137,6 +137,9 @@ def pick_operation(name: str, operation: int, discard: bool = False) -> ChangeSe def sync_with_server(name: str, operation: int) -> ChangeSet: return api.sync_with_server(name, operation) +def execute_batch_command(name: str, cs: ChangeSet) -> ChangeSet: + return api.execute_batch_command(name, cs) + def execute_batch_commands(name: str, cs: ChangeSet) -> ChangeSet: return api.execute_batch_commands(name, cs)