From 9e8a84c4c053e68ef3617ffb62a624766cc045cb Mon Sep 17 00:00:00 2001 From: "WQY\\qiong" Date: Sat, 22 Oct 2022 13:17:42 +0800 Subject: [PATCH] Add curve api and test --- api/__init__.py | 2 ++ api/command.py | 6 +++++ api/s11_patterns.py | 2 +- api/s12_curves.py | 48 +++++++++++++++++++++++++++++++++++ test_tjnetwork.py | 61 ++++++++++++++++++++++++++++++++++++++++++++- tjnetwork.py | 14 +++++++++++ 6 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 api/s12_curves.py diff --git a/api/__init__.py b/api/__init__.py index b5aa5bc..481cbf2 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -45,4 +45,6 @@ from .s10_status import get_status_schema, get_status, set_status from .s11_patterns import get_pattern_schema, get_pattern, set_pattern +from .s12_curves import get_curve_schema, get_curve, set_curve + from .s24_coordinates import get_node_coord diff --git a/api/command.py b/api/command.py index 37d06da..ffe7445 100644 --- a/api/command.py +++ b/api/command.py @@ -7,6 +7,8 @@ from .s6_pumps import * from .s7_valves import * from .s9_demands import * from .s10_status import * +from .s11_patterns import * +from .s12_curves import * def execute_add_command(name: str, cs: ChangeSet) -> ChangeSet: @@ -49,6 +51,10 @@ def execute_update_command(name: str, cs: ChangeSet) -> ChangeSet: return set_demand(name, cs) elif type == 'status': return set_status(name, cs) + elif type == PATTERN: + return set_pattern(name, cs) + elif type == CURVE: + return set_curve(name, cs) return ChangeSet() diff --git a/api/s11_patterns.py b/api/s11_patterns.py index c4cf382..dc4d996 100644 --- a/api/s11_patterns.py +++ b/api/s11_patterns.py @@ -30,7 +30,7 @@ def set_pattern(name: str, cs: ChangeSet) -> ChangeSet: for factor in cs.operations[0]['factors']: f_factor = float(factor) redo_sql += f"\ninsert into patterns (id, factor) values ({f_id}, {f_factor});" - new['factors'].append(f_factor) + new['factors'].append(factor) undo_sql = f"delete from patterns where id = {f_id};" undo_sql += f"\ndelete from _pattern where id = {f_id};" diff --git a/api/s12_curves.py b/api/s12_curves.py new file mode 100644 index 0000000..42452f7 --- /dev/null +++ b/api/s12_curves.py @@ -0,0 +1,48 @@ +from .operation import * +from .s0_base import * + + +def get_curve_schema(name: str) -> dict[str, dict[str, Any]]: + return { 'id' : {'type': 'str' , 'optional': False , 'readonly': True }, + 'coords' : {'type': 'list' , 'optional': False , 'readonly': False, + 'element': { 'x' : {'type': 'float' , 'optional': False , 'readonly': False }, + 'y' : {'type': 'float' , 'optional': False , 'readonly': False } }}} + + +def get_curve(name: str, id: str) -> dict[str, Any]: + cus = read_all(name, f"select * from curves where id = '{id}'") + cs = [] + for r in cus: + cs.append({ 'x': float(r['x']), 'y': float(r['y']) }) + return { 'id': id, 'coords': cs } + + +def set_curve(name: str, cs: ChangeSet) -> ChangeSet: + id = cs.operations[0]['id'] + + old = get_curve(name, id) + new = { 'id': id, 'coords': [] } + + f_id = f"'{id}'" + + # TODO: transaction ? + redo_sql = f"delete from curves where id = {f_id};" + redo_sql += f"\ndelete from _curve where id = {f_id};" + redo_sql += f"\ninsert into _curve (id) values ({f_id});" + for xy in cs.operations[0]['coords']: + x, y = float(xy['x']), float(xy['y']) + f_x, f_y = x, y + redo_sql += f"\ninsert into curves (id, x, y) values ({f_id}, {f_x}, {f_y});" + new['coords'].append({ 'x': x, 'y': y }) + + undo_sql = f"delete from curves where id = {f_id};" + undo_sql += f"\ndelete from _curve where id = {f_id};" + undo_sql += f"\ninsert into _curve (id) values ({f_id});" + for xy in old['coords']: + f_x, f_y = xy['x'], xy['y'] + undo_sql += f"\ninsert into curves (id, x, y) values ({f_id}, {f_x}, {f_y});" + + redo_cs = g_update_prefix | { 'type': 'curve' } | new + undo_cs = g_update_prefix | { 'type': 'curve' } | old + + return execute_command(name, redo_sql, undo_sql, redo_cs, undo_cs) diff --git a/test_tjnetwork.py b/test_tjnetwork.py index a833670..b41876a 100644 --- a/test_tjnetwork.py +++ b/test_tjnetwork.py @@ -1313,7 +1313,7 @@ class TestApi: def test_pattern(self): - p = 'test_demand' + p = 'test_pattern' self.enter(p) assert is_pattern(p, 'p0') == False @@ -1356,6 +1356,65 @@ class TestApi: self.leave(p) + def test_curve(self): + p = 'test_curve' + self.enter(p) + + assert is_curve(p, 'c0') == False + c0 = get_curve(p, 'c0') + assert c0['id'] == 'c0' + assert c0['coords'] == [] + + set_curve(p, ChangeSet({'id' : 'c0', 'coords': [{'x': 1.0, 'y': 2.0}, {'x': 2.0, 'y': 1.0}]})) + + assert is_curve(p, 'c0') + c0 = get_curve(p, 'c0') + assert c0['id'] == 'c0' + xys = c0['coords'] + assert len(xys) == 2 + assert xys[0]['x'] == 1.0 + assert xys[0]['y'] == 2.0 + assert xys[1]['x'] == 2.0 + assert xys[1]['y'] == 1.0 + + self.leave(p) + + + def test_curve_op(self): + p = 'test_curve_op' + self.enter(p) + + cs = set_curve(p, ChangeSet({'id' : 'c0', 'coords': [{'x': 1.0, 'y': 2.0}, {'x': 2.0, 'y': 1.0}]})).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == CURVE + assert cs['id'] == 'c0' + xys = cs['coords'] + assert len(xys) == 2 + assert xys[0]['x'] == 1.0 + assert xys[0]['y'] == 2.0 + assert xys[1]['x'] == 2.0 + assert xys[1]['y'] == 1.0 + + cs = execute_undo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == CURVE + assert cs['id'] == 'c0' + assert cs['coords'] == [] + + cs = execute_redo(p).operations[0] + assert cs['operation'] == API_UPDATE + assert cs['type'] == CURVE + assert cs['id'] == 'c0' + xys = cs['coords'] + assert len(xys) == 2 + assert xys[0]['x'] == 1.0 + assert xys[0]['y'] == 2.0 + assert xys[1]['x'] == 2.0 + assert xys[1]['y'] == 1.0 + + self.leave(p) + + def test_snapshot(self): p = "test_snapshot" self.enter(p) diff --git a/tjnetwork.py b/tjnetwork.py index f90ad93..4e10519 100644 --- a/tjnetwork.py +++ b/tjnetwork.py @@ -339,6 +339,20 @@ def set_pattern(name: str, cs: ChangeSet) -> ChangeSet: return api.set_pattern(name, cs) +############################################################ +# curve 11.[CURVES] +############################################################ + +def get_curve_schema(name: str) -> dict[str, dict[str, Any]]: + return api.get_curve_schema(name) + +def get_curve(name: str, id: str) -> dict[str, Any]: + return api.get_curve(name, id) + +def set_curve(name: str, cs: ChangeSet) -> ChangeSet: + return api.set_curve(name, cs) + + ############################################################ # coord 24.[COORDINATES] ############################################################