Skip to content

Commit ceb47e1

Browse files
BUG: AMAE and MMAE correction (#24)
* BUG: Fix AMAE and MMAE for missing classes in y_true * TST: Extend AMAE and MMAE tests
1 parent f9d8081 commit ceb47e1

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

orca_python/metrics/metrics.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,13 @@ def amae(y_true, y_pred):
121121
costs = np.reshape(np.tile(range(n_class), n_class), (n_class, n_class))
122122
costs = np.abs(costs - np.transpose(costs))
123123
errors = costs * cm
124+
125+
# Remove rows with all zeros in the confusion matrix
126+
non_zero_cm_rows = ~np.all(cm == 0, axis=1)
127+
errors = errors[non_zero_cm_rows]
128+
cm = cm[non_zero_cm_rows]
129+
124130
per_class_maes = np.sum(errors, axis=1) / np.sum(cm, axis=1).astype("double")
125-
per_class_maes = per_class_maes[~np.isnan(per_class_maes)]
126131
return np.mean(per_class_maes)
127132

128133

@@ -249,8 +254,13 @@ def mmae(y_true, y_pred):
249254
costs = np.reshape(np.tile(range(n_class), n_class), (n_class, n_class))
250255
costs = np.abs(costs - np.transpose(costs))
251256
errors = costs * cm
257+
258+
# Remove rows with all zeros in the confusion matrix
259+
non_zero_cm_rows = ~np.all(cm == 0, axis=1)
260+
errors = errors[non_zero_cm_rows]
261+
cm = cm[non_zero_cm_rows]
262+
252263
per_class_maes = np.sum(errors, axis=1) / np.sum(cm, axis=1).astype("double")
253-
per_class_maes = per_class_maes[~np.isnan(per_class_maes)]
254264
return per_class_maes.max()
255265

256266

orca_python/metrics/tests/test_metrics.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def test_amae():
8181
actual = amae(y_true, y_pred)
8282
npt.assert_almost_equal(expected, actual, decimal=6)
8383

84+
y_true = np.array([0, 1, 2, 3, 3])
85+
y_pred = np.array([0, 1, 2, 3, 4])
86+
expected = 0.125
87+
actual = amae(y_true, y_pred)
88+
npt.assert_almost_equal(expected, actual, decimal=6)
89+
8490

8591
def test_gm():
8692
"""Test the Geometric Mean (GM) metric."""
@@ -147,6 +153,12 @@ def test_mmae():
147153
actual = mmae(y_true, y_pred)
148154
npt.assert_almost_equal(expected, actual, decimal=6)
149155

156+
y_true = np.array([0, 1, 2, 3, 3])
157+
y_pred = np.array([0, 1, 2, 3, 4])
158+
expected = 0.5
159+
actual = mmae(y_true, y_pred)
160+
npt.assert_almost_equal(expected, actual, decimal=6)
161+
150162

151163
def test_ms():
152164
"""Test the Mean Sensitivity (MS) metric."""

0 commit comments

Comments
 (0)