Adding mincdd comparison and better debug information

This commit is contained in:
Michael Tryby
2018-07-10 18:50:33 -04:00
parent 51f5295e8c
commit f38a637aaf
4 changed files with 84 additions and 25 deletions

View File

@@ -38,7 +38,7 @@ __status = "Development"
def epanet_allclose_compare(path_test, path_ref, rtol, atol):
'''
Compares results in two EPANET binary files. Using the comparison criteria
Compares results in two EPANET binary files using the comparison criteria
described in the numpy assert_allclose documentation.
(test_value - ref_value) <= atol + rtol * abs(ref_value)
@@ -67,22 +67,67 @@ def epanet_allclose_compare(path_test, path_ref, rtol, atol):
for (test, ref) in it.izip(ordr.output_generator(path_test),
ordr.output_generator(path_ref)):
if len(test) != len(ref):
if len(test[0]) != len(ref[0]):
raise ValueError('Inconsistent lengths')
# Skip over arrays that are equal
if np.array_equal(test, ref):
if np.array_equal(test[0], ref[0]):
continue
else:
np.testing.assert_allclose(test, ref, rtol, atol)
np.testing.assert_allclose(test[0], ref[0], rtol, atol)
return True
# def epanet_better_compare(path_test, path_ref, rtol, atol):
# '''
# If you don't like assert_allclose you can add another function here.
# '''
# pass
def epanet_mincdd_compare(path_test, path_ref, rtol, atol):
'''
Compares the results of two EPANET binary files using a correct decimal
digits (cdd) comparison criteria:
min cdd(test, ref) >= atol
Returns true if min cdd in the file is greater than or equal to atol,
otherwise an AssertionError is thrown.
Arguments:
path_test - path to result file being testedgit
path_ref - path to reference result file
rtol - ignored
atol - minimum allowable cdd value (i.e. 3)
Returns:
True
Raises:
ValueError()
AssertionError()
'''
min_cdd = 100.0
for (test, ref) in it.izip(ordr.output_generator(path_test),
ordr.output_generator(path_ref)):
if len(test[0]) != len(ref[0]):
raise ValueError('Inconsistent lengths')
# Skip over arrays that are equal
if np.array_equal(test[0], ref[0]):
continue
else:
diff = np.fabs(np.subtract(test[0], ref[0]))
idx = np.unravel_index(np.argmax(diff), diff.shape)
if diff[idx] != 0.0:
tmp = - np.log10(diff[idx])
if tmp < min_cdd:
min_cdd = tmp;
if np.floor(min_cdd) >= atol:
return True
else:
raise AssertionError('min_cdd=%d less than atol=%g' % (min_cdd, atol))
def epanet_report_compare(path_test, path_ref, rtol, atol):
'''