From bb1d772eaac38acaaa876a822d7001411456182f Mon Sep 17 00:00:00 2001 From: "WQY\\qiong" Date: Sat, 22 Oct 2022 12:55:16 +0800 Subject: [PATCH] Add pattern api and test --- api/__init__.py | 4 ++- api/s0_base.py | 23 +++++++++------- api/s11_patterns.py | 44 +++++++++++++++++++++++++++++++ api/s9_demands.py | 1 + script/sql/create/11.patterns.sql | 2 +- test_tjnetwork.py | 44 +++++++++++++++++++++++++++++++ tjnetwork.py | 18 ++++++++++++- 7 files changed, 123 insertions(+), 13 deletions(-) create mode 100644 api/s11_patterns.py diff --git a/api/__init__.py b/api/__init__.py index 88d57d6..b5aa5bc 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -13,7 +13,7 @@ from .operation import sync_with_server from .command import execute_batch_commands -from .s0_base import JUNCTION, RESERVOIR, TANK, PIPE, PUMP, VALVE +from .s0_base import JUNCTION, RESERVOIR, TANK, PIPE, PUMP, VALVE, PATTERN, CURVE from .s0_base import is_node, is_junction, is_reservoir, is_tank from .s0_base import is_link, is_pipe, is_pump, is_valve from .s0_base import is_curve @@ -43,4 +43,6 @@ from .s9_demands import get_demand_schema, get_demand, set_demand from .s10_status import LINK_STATUS_OPEN, LINK_STATUS_CLOSED, LINK_STATUS_ACTIVE from .s10_status import get_status_schema, get_status, set_status +from .s11_patterns import get_pattern_schema, get_pattern, set_pattern + from .s24_coordinates import get_node_coord diff --git a/api/s0_base.py b/api/s0_base.py index 006373a..899f3fa 100644 --- a/api/s0_base.py +++ b/api/s0_base.py @@ -3,17 +3,20 @@ from .connection import g_conn_dict as conn from .operation import * -_NODE = "_node" -_LINK = "_link" -_CURVE = "_curve" -_PATTERN = "_pattern" +_NODE = '_node' +_LINK = '_link' +_CURVE = '_curve' +_PATTERN = '_pattern' -JUNCTION = "junction" -RESERVOIR = "reservoir" -TANK = "tank" -PIPE = "pipe" -PUMP = "pump" -VALVE = "valve" +JUNCTION = 'junction' +RESERVOIR = 'reservoir' +TANK = 'tank' +PIPE = 'pipe' +PUMP = 'pump' +VALVE = 'valve' + +PATTERN = 'pattern' +CURVE = 'curve' def _get_from(name: str, id: str, base_type: str) -> Row | None: diff --git a/api/s11_patterns.py b/api/s11_patterns.py new file mode 100644 index 0000000..c4cf382 --- /dev/null +++ b/api/s11_patterns.py @@ -0,0 +1,44 @@ +from .operation import * +from .s0_base import * + + +def get_pattern_schema(name: str) -> dict[str, dict[str, Any]]: + return { 'id' : {'type': 'str' , 'optional': False , 'readonly': True }, + 'factors' : {'type': 'float_list' , 'optional': False , 'readonly': False } } + + +def get_pattern(name: str, id: str) -> dict[str, Any]: + pas = read_all(name, f"select * from patterns where id = '{id}'") + ps = [] + for r in pas: + ps.append(float(r['factor'])) + return { 'id': id, 'factors': ps } + + +def set_pattern(name: str, cs: ChangeSet) -> ChangeSet: + id = cs.operations[0]['id'] + + old = get_pattern(name, id) + new = { 'id': id, 'factors': [] } + + f_id = f"'{id}'" + + # TODO: transaction ? + redo_sql = f"delete from patterns where id = {f_id};" + redo_sql += f"\ndelete from _pattern where id = {f_id};" + redo_sql += f"\ninsert into _pattern (id) values ({f_id});" + for factor in cs.operations[0]['factors']: + f_factor = float(factor) + redo_sql += f"\ninsert into patterns (id, factor) values ({f_id}, {f_factor});" + new['factors'].append(f_factor) + + undo_sql = f"delete from patterns where id = {f_id};" + undo_sql += f"\ndelete from _pattern where id = {f_id};" + undo_sql += f"\ninsert into _pattern (id) values ({f_id});" + for f_factor in old['factors']: + undo_sql += f"\ninsert into patterns (id, factor) values ({f_id}, {f_factor});" + + redo_cs = g_update_prefix | { 'type': 'pattern' } | new + undo_cs = g_update_prefix | { 'type': 'pattern' } | old + + return execute_command(name, redo_sql, undo_sql, redo_cs, undo_cs) diff --git a/api/s9_demands.py b/api/s9_demands.py index 01b798b..cb67d6d 100644 --- a/api/s9_demands.py +++ b/api/s9_demands.py @@ -29,6 +29,7 @@ def set_demand(name: str, cs: ChangeSet) -> ChangeSet: f_junction = f"'{junction}'" + # TODO: transaction ? redo_sql = f"delete from demands where junction = {f_junction};" for r in cs.operations[0]['demands']: demand = float(r['demand']) diff --git a/script/sql/create/11.patterns.sql b/script/sql/create/11.patterns.sql index d6aaa4e..8835fa2 100644 --- a/script/sql/create/11.patterns.sql +++ b/script/sql/create/11.patterns.sql @@ -3,5 +3,5 @@ create table patterns ( id varchar(32) references _pattern(id) not null -, multipliers numeric not null +, factor numeric not null ); diff --git a/test_tjnetwork.py b/test_tjnetwork.py index e7a66b8..a833670 100644 --- a/test_tjnetwork.py +++ b/test_tjnetwork.py @@ -1312,6 +1312,50 @@ class TestApi: self.leave(p) + def test_pattern(self): + p = 'test_demand' + self.enter(p) + + assert is_pattern(p, 'p0') == False + p0 = get_pattern(p, 'p0') + assert p0['id'] == 'p0' + assert p0['factors'] == [] + + set_pattern(p, ChangeSet({'id' : 'p0', 'factors': [1.0, 2.0, 3.0]})) + + assert is_pattern(p, 'p0') + p0 = get_pattern(p, 'p0') + assert p0['id'] == 'p0' + assert p0['factors'] == [1.0, 2.0, 3.0] + + self.leave(p) + + + def test_pattern_op(self): + p = 'test_pattern_op' + self.enter(p) + + cs = set_pattern(p, ChangeSet({'id' : 'p0', 'factors': [1.0, 2.0, 3.0]})).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == PATTERN + assert cs['id'] == 'p0' + assert cs['factors'] == [1.0, 2.0, 3.0] + + cs = execute_undo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == PATTERN + assert cs['id'] == 'p0' + assert cs['factors'] == [] + + cs = execute_redo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == PATTERN + assert cs['id'] == 'p0' + assert cs['factors'] == [1.0, 2.0, 3.0] + + self.leave(p) + + def test_snapshot(self): p = "test_snapshot" self.enter(p) diff --git a/tjnetwork.py b/tjnetwork.py index fc90d54..f90ad93 100644 --- a/tjnetwork.py +++ b/tjnetwork.py @@ -23,6 +23,8 @@ TANK = api.TANK PIPE = api.PIPE PUMP = api.PUMP VALVE = api.VALVE +PATTERN = api.PATTERN +CURVE = api.CURVE OVERFLOW_YES = api.OVERFLOW_YES OVERFLOW_NO = api.OVERFLOW_NO @@ -295,7 +297,7 @@ def delete_valve(name: str, cs: ChangeSet) -> ChangeSet: ############################################################ -# demands 9.[DEMANDS] +# demand 9.[DEMANDS] ############################################################ def get_demand_schema(name: str) -> dict[str, dict[str, Any]]: @@ -323,6 +325,20 @@ def set_status(name: str, cs: ChangeSet) -> ChangeSet: return api.set_status(name, cs) +############################################################ +# pattern 11.[PATTERNS] +############################################################ + +def get_pattern_schema(name: str) -> dict[str, dict[str, Any]]: + return api.get_pattern_schema(name) + +def get_pattern(name: str, id: str) -> dict[str, Any]: + return api.get_pattern(name, id) + +def set_pattern(name: str, cs: ChangeSet) -> ChangeSet: + return api.set_pattern(name, cs) + + ############################################################ # coord 24.[COORDINATES] ############################################################