diff --git a/api/__init__.py b/api/__init__.py index 47b16ae..a10ca2b 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -22,6 +22,8 @@ from .database import read, try_read, read_all, write from .batch_exe import execute_batch_commands, execute_batch_command +from .extension_data import get_all_extension_data_keys, get_all_extension_data, get_extension_data, set_extension_data + from .s0_base import JUNCTION, RESERVOIR, TANK, PIPE, PUMP, VALVE, PATTERN, CURVE from .s0_base import is_node, is_junction, is_reservoir, is_tank from .s0_base import is_link, is_pipe, is_pump, is_valve diff --git a/api/batch_exe.py b/api/batch_exe.py index 7f28c4e..4b82497 100644 --- a/api/batch_exe.py +++ b/api/batch_exe.py @@ -1,6 +1,7 @@ from typing import Any from .sections import * from .database import API_ADD, API_UPDATE, API_DELETE, ChangeSet, write, read, read_all, get_current_operation +from .extension_data import set_extension_data from .s1_title import set_title from .s2_junctions import set_junction, add_junction, delete_junction from .s3_reservoirs import set_reservoir, add_reservoir, delete_reservoir @@ -112,6 +113,8 @@ def _execute_add_command(name: str, cs: ChangeSet) -> ChangeSet: def _execute_update_command(name: str, cs: ChangeSet) -> ChangeSet: type = cs.operations[0]['type'] + if type == 'extension_data': + return set_extension_data(name, cs) if type == s1_title: return set_title(name, cs) if type == s2_junction: diff --git a/api/extension_data.py b/api/extension_data.py new file mode 100644 index 0000000..cfce4ea --- /dev/null +++ b/api/extension_data.py @@ -0,0 +1,62 @@ +from .database import * + + +def get_all_extension_data_keys(name: str) -> list[str]: + result: list[str] = [] + for row in read_all(name, 'select key from extension_data'): + result.append(row['key']) + return result + + +def get_all_extension_data(name: str) -> dict[str, Any]: + result: dict[str, Any] = {} + for row in read_all(name, 'select key, value from extension_data'): + result[row['key']] = row['value'] + return result + + +def get_extension_data(name: str, key: str) -> str | None: + if key == None or key == '': + return None + row = try_read(name, f"select value from extension_data where key = '{key}'") + if row == None: + return None + return row['value'] + + +def _set_extension_data(name: str, cs: ChangeSet) -> DbChangeSet: + op = cs.operations[0] + key, new_val = op['key'], op['value'] + + f_new_val = f"'{new_val}'" if new_val != None else 'null' + + old_val = get_extension_data(name, key) + f_old_val = f"'{old_val}'" if old_val != None else 'null' + + redo_sql = f"delete from extension_data where key = '{key}';" + if new_val != None: + redo_sql += f"insert into extension_data (key, value) values ('{key}', {f_new_val});" + + undo_sql = f"delete from extension_data where key = '{key}';" + if old_val != None: + undo_sql += f"insert into extension_data (key, value) values ('{key}', {f_old_val});" + + redo_cs = g_update_prefix | { 'type': 'extension_data', 'key': key, 'value': new_val } + undo_cs = g_update_prefix | { 'type': 'extension_data', 'key': key, 'value': old_val } + + return DbChangeSet(redo_sql, undo_sql, [redo_cs], [undo_cs]) + + +def set_extension_data(name: str, cs: ChangeSet) -> ChangeSet: + if len(cs.operations) != 1: + return ChangeSet() + + op = cs.operations[0] + if 'key' not in op or 'value' not in op: + return ChangeSet() + + key = op['key'] + if key == None or key == '': + return ChangeSet() + + return execute_command(name, _set_extension_data(name, cs)) diff --git a/create_template.py b/create_template.py index c0260c2..80f2126 100644 --- a/create_template.py +++ b/create_template.py @@ -34,11 +34,13 @@ sql_create = [ "script/sql/create/30.scada_device_data.sql", "script/sql/create/31.scada_element.sql", "script/sql/create/32.region.sql", + "script/sql/create/extension_data.sql", "script/sql/create/operation.sql" ] sql_drop = [ "script/sql/drop/operation.sql", + "script/sql/create/extension_data.sql", "script/sql/drop/32.region.sql", "script/sql/drop/31.scada_element.sql", "script/sql/drop/30.scada_device_data.sql", diff --git a/script/sql/create/extension_data.sql b/script/sql/create/extension_data.sql new file mode 100644 index 0000000..491cbaf --- /dev/null +++ b/script/sql/create/extension_data.sql @@ -0,0 +1,5 @@ +create table extension_data +( + key text primary key +, value text not null +); diff --git a/script/sql/drop/extension_data.sql b/script/sql/drop/extension_data.sql new file mode 100644 index 0000000..ba82d5f --- /dev/null +++ b/script/sql/drop/extension_data.sql @@ -0,0 +1 @@ +drop table if exists extension_data; diff --git a/test_tjnetwork.py b/test_tjnetwork.py index bdbf5b4..7a239f9 100644 --- a/test_tjnetwork.py +++ b/test_tjnetwork.py @@ -203,6 +203,81 @@ class TestApi: self.leave(p) + # extension_data + + + def test_extension_data(self): + p = 'test_extension_data' + self.enter(p) + + assert get_all_extension_data_keys(p) == [] + assert get_all_extension_data(p) == {} + assert get_extension_data(p, '') == None + + set_extension_data(p, ChangeSet({'key': 'key', 'value': None})) + assert get_extension_data(p, 'key') == None + + set_extension_data(p, ChangeSet({'key': 'key', 'value': ''})) + assert get_extension_data(p, 'key') == '' + + set_extension_data(p, ChangeSet({'key': 'key', 'value': 'value'})) + assert get_extension_data(p, 'key') == 'value' + + set_extension_data(p, ChangeSet({'key': 'key', 'value': 'val'})) + assert get_extension_data(p, 'key') == 'val' + + set_extension_data(p, ChangeSet({'key': 'key1', 'value': 'val1'})) + assert get_extension_data(p, 'key1') == 'val1' + + assert get_all_extension_data_keys(p) == ['key', 'key1'] + assert get_all_extension_data(p) == {'key': 'val', 'key1': 'val1'} + + self.leave(p) + + + def test_extension_data_op(self): + p = 'test_extension_data_op' + self.enter(p) + + cs = set_extension_data(p, ChangeSet({'key': 'key', 'value': ''})).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == '' + + cs = execute_undo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == None + + cs = execute_redo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == '' + + cs = set_extension_data(p, ChangeSet({'key': 'key', 'value': 'value'})).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == 'value' + + cs = execute_undo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == '' + + cs = execute_redo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == 'extension_data' + assert cs['key'] == 'key' + assert cs['value'] == 'value' + + self.leave(p) + + # complex test diff --git a/tjnetwork.py b/tjnetwork.py index 2bdff0f..5181f72 100644 --- a/tjnetwork.py +++ b/tjnetwork.py @@ -321,6 +321,23 @@ def write(name: str, sql: str): return api.write(name, sql) +############################################################ +# extension_data +############################################################ + +def get_all_extension_data_keys(name: str) -> list[str]: + return api.get_all_extension_data_keys(name) + +def get_all_extension_data(name: str) -> dict[str, Any]: + return api.get_all_extension_data(name) + +def get_extension_data(name: str, key: str) -> str | None: + return api.get_extension_data(name, key) + +def set_extension_data(name: str, cs: ChangeSet) -> ChangeSet: + return api.set_extension_data(name, cs) + + ############################################################ # type ############################################################