
Commit 74fd97c

ENH: Add metric utilities and scoring functions (#32)
* ENH: Add metrics/utils.py with metric registry and helpers
* REF: Move greater_is_better from metrics.py to utils.py
* ENH: Add load_metric_as_scorer and compute_metric functions
* ENH: Expose utility functions in metrics/__init__.py
* TST: Add tests for metric utilities and scorers
1 parent e11084a commit 74fd97c

5 files changed: +388 additions, -62 deletions

orca_python/metrics/__init__.py

Lines changed: 10 additions & 2 deletions
@@ -6,7 +6,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -16,9 +15,14 @@
     tkendall,
     wkappa,
 )
+from .utils import (
+    compute_metric,
+    get_metric_names,
+    greater_is_better,
+    load_metric_as_scorer,
+)
 
 __all__ = [
-    "greater_is_better",
     "ccr",
     "amae",
     "gm",
@@ -32,4 +36,8 @@
     "spearman",
     "rps",
     "accuracy_off1",
+    "get_metric_names",
+    "greater_is_better",
+    "load_metric_as_scorer",
+    "compute_metric",
 ]
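With these re-exports in place, the helpers can be used straight from the package namespace. A small usage sketch (hypothetical session; the expected values follow from the metric definitions and the tests below, not from this file):

    from orca_python.metrics import compute_metric, get_metric_names, greater_is_better

    print(get_metric_names()[:3])                        # ['accuracy_off1', 'amae', 'ccr']
    print(greater_is_better("ccr"))                      # True: higher accuracy is better
    print(greater_is_better("mae"))                      # False: lower error is better
    print(compute_metric("ccr", [1, 2, 2], [1, 2, 1]))   # 2 of 3 correct, roughly 0.667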

orca_python/metrics/metrics.py

Lines changed: 0 additions & 42 deletions
@@ -7,48 +7,6 @@
 from sklearn.metrics import confusion_matrix, recall_score
 
 
-def greater_is_better(metric_name):
-    """Determine if greater values indicate better classification performance.
-
-    Needed when declaring a new scorer through make_scorer from sklearn.
-
-    Parameters
-    ----------
-    metric_name : str
-        Name of the metric.
-
-    Returns
-    -------
-    greater_is_better : bool
-        True if greater values indicate better classification performance, False otherwise.
-
-    Examples
-    --------
-    >>> from orca_python.metrics.metrics import greater_is_better
-    >>> greater_is_better("ccr")
-    True
-    >>> greater_is_better("mze")
-    False
-    >>> greater_is_better("mae")
-    False
-
-    """
-    greater_is_better_metrics = [
-        "ccr",
-        "ms",
-        "gm",
-        "gmsec",
-        "tkendall",
-        "wkappa",
-        "spearman",
-        "accuracy_off1",
-    ]
-    if metric_name in greater_is_better_metrics:
-        return True
-    else:
-        return False
-
-
 def ccr(y_true, y_pred):
     """Calculate the Correctly Classified Ratio.
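As its docstring notes, the removed helper existed so that hand-built sklearn scorers received the correct sign via make_scorer. An illustrative comparison of the old and new patterns (hypothetical call sites, not code from this commit):

    from sklearn.metrics import make_scorer

    from orca_python.metrics import greater_is_better, load_metric_as_scorer, mae

    # Old pattern: look up the sign, then wrap the metric by hand.
    scorer_old = make_scorer(mae, greater_is_better=greater_is_better("mae"))

    # New pattern: the utils module builds the equivalent scorer in one call.
    scorer_new = load_metric_as_scorer("mae")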

orca_python/metrics/tests/test_metrics.py

Lines changed: 0 additions & 18 deletions
@@ -10,7 +10,6 @@
     ccr,
     gm,
     gmsec,
-    greater_is_better,
     mae,
     mmae,
     ms,
@@ -22,23 +21,6 @@
 )
 
 
-def test_greater_is_better():
-    """Test the greater_is_better function."""
-    assert greater_is_better("accuracy_off1")
-    assert greater_is_better("ccr")
-    assert greater_is_better("gm")
-    assert greater_is_better("gmsec")
-    assert not greater_is_better("mae")
-    assert not greater_is_better("mmae")
-    assert not greater_is_better("amae")
-    assert greater_is_better("ms")
-    assert not greater_is_better("mze")
-    assert not greater_is_better("rps")
-    assert greater_is_better("tkendall")
-    assert greater_is_better("wkappa")
-    assert greater_is_better("spearman")
-
-
 def test_accuracy_off1():
     """Test the Accuracy that allows errors in adjacent classes."""
     y_true = np.array([0, 1, 2, 3, 4, 5])
Lines changed: 171 additions & 0 deletions (new file)
@@ -0,0 +1,171 @@
+"""Tests for the metrics module utilities."""
+
+import numpy.testing as npt
+import pytest
+
+from orca_python.metrics import (
+    accuracy_off1,
+    amae,
+    ccr,
+    gm,
+    gmsec,
+    mae,
+    mmae,
+    ms,
+    mze,
+    rps,
+    spearman,
+    tkendall,
+    wkappa,
+)
+from orca_python.metrics.utils import (
+    _METRICS,
+    compute_metric,
+    get_metric_names,
+    greater_is_better,
+    load_metric_as_scorer,
+)
+
+
+def test_get_metric_names():
+    """Test that get_metric_names returns all available metric names."""
+    all_metrics = get_metric_names()
+    expected_names = list(_METRICS.keys())
+
+    assert type(all_metrics) is list
+    assert all_metrics[:3] == ["accuracy_off1", "amae", "ccr"]
+    assert "rps" in all_metrics
+    npt.assert_array_equal(sorted(all_metrics), sorted(expected_names))
+
+
+@pytest.mark.parametrize(
+    "metric_name, gib",
+    [
+        ("accuracy_off1", True),
+        ("amae", False),
+        ("ccr", True),
+        ("gm", True),
+        ("gmsec", True),
+        ("mae", False),
+        ("mmae", False),
+        ("ms", True),
+        ("mze", False),
+        ("rps", False),
+        ("spearman", True),
+        ("tkendall", True),
+        ("wkappa", True),
+    ],
+)
+def test_greater_is_better(metric_name, gib):
+    """Test that greater_is_better returns the correct boolean for each metric."""
+    assert greater_is_better(metric_name) == gib
+
+
+def test_greater_is_better_invalid_name():
+    """Test that greater_is_better raises an error for an invalid metric name."""
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+
+    with pytest.raises(KeyError, match=error_msg):
+        greater_is_better("roc_auc")
+
+
+@pytest.mark.parametrize(
+    "metric_name, metric",
+    [
+        ("rps", rps),
+        ("ccr", ccr),
+        ("accuracy_off1", accuracy_off1),
+        ("gm", gm),
+        ("gmsec", gmsec),
+        ("mae", mae),
+        ("mmae", mmae),
+        ("amae", amae),
+        ("ms", ms),
+        ("mze", mze),
+        ("tkendall", tkendall),
+        ("wkappa", wkappa),
+        ("spearman", spearman),
+    ],
+)
+def test_load_metric_as_scorer(metric_name, metric):
+    """Test that load_metric_as_scorer correctly loads the expected metric."""
+    metric_func = load_metric_as_scorer(metric_name)
+
+    assert metric_func._score_func == metric
+    assert metric_func._sign == (1 if greater_is_better(metric_name) else -1)
+
+
+@pytest.mark.parametrize(
+    "metric_name, metric",
+    [
+        ("ccr", ccr),
+        ("accuracy_off1", accuracy_off1),
+        ("gm", gm),
+        ("gmsec", gmsec),
+        ("mae", mae),
+        ("mmae", mmae),
+        ("amae", amae),
+        ("ms", ms),
+        ("mze", mze),
+        ("tkendall", tkendall),
+        ("wkappa", wkappa),
+        ("spearman", spearman),
+    ],
+)
+def test_correct_metric_output(metric_name, metric):
+    """Test that the loaded metric function produces the same output as the
+    original metric."""
+    y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+    y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
+    metric_func = load_metric_as_scorer(metric_name)
+    metric_true = metric(y_true, y_pred)
+    metric_pred = metric_func._score_func(y_true, y_pred)
+
+    npt.assert_almost_equal(metric_pred, metric_true, decimal=6)
+
+
+def test_load_metric_invalid_name():
+    """Test that loading an invalid metric raises the correct exception."""
+    error_msg = "metric_name must be a string."
+    with pytest.raises(TypeError, match=error_msg):
+        load_metric_as_scorer(123)
+
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+    with pytest.raises(KeyError, match=error_msg):
+        load_metric_as_scorer("roc_auc")
+
+
+@pytest.mark.parametrize(
+    "metric_name",
+    [
+        "ccr",
+        "accuracy_off1",
+        "gm",
+        "gmsec",
+        "mae",
+        "mmae",
+        "amae",
+        "ms",
+        "mze",
+        "tkendall",
+        "wkappa",
+        "spearman",
+    ],
+)
+def test_compute_metric(metric_name) -> None:
+    """Test that compute_metric returns the correct metric value."""
+    y_true = [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+    y_pred = [1, 3, 3, 1, 2, 3, 1, 2, 2, 1, 3, 1, 1, 2, 2, 2, 3, 3, 1, 3]
+    metric_value = compute_metric(metric_name, y_true, y_pred)
+    metric_func = load_metric_as_scorer(metric_name)
+    metric_true = metric_func._score_func(y_true, y_pred)
+
+    npt.assert_almost_equal(metric_value, metric_true, decimal=6)
+
+
+def test_compute_metric_invalid_name():
+    """Test that compute_metric raises an error for an invalid metric name."""
+    error_msg = "Unrecognized metric name: 'roc_auc'."
+
+    with pytest.raises(KeyError, match=error_msg):
+        compute_metric("roc_auc", [1, 2, 3], [1, 2, 3])
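The new orca_python/metrics/utils.py module itself is not shown above, only its tests. A minimal sketch that would satisfy these assertions (registry contents, signatures, and error messages inferred from the tests; not the committed implementation) could look like:

    """Hypothetical sketch of the metrics utilities, inferred from the tests above."""

    from sklearn.metrics import make_scorer

    from .metrics import (
        accuracy_off1, amae, ccr, gm, gmsec, mae, mmae,
        ms, mze, rps, spearman, tkendall, wkappa,
    )

    # Registry mapping each metric name to (callable, greater_is_better).
    _METRICS = {
        "accuracy_off1": (accuracy_off1, True),
        "amae": (amae, False),
        "ccr": (ccr, True),
        "gm": (gm, True),
        "gmsec": (gmsec, True),
        "mae": (mae, False),
        "mmae": (mmae, False),
        "ms": (ms, True),
        "mze": (mze, False),
        "rps": (rps, False),
        "spearman": (spearman, True),
        "tkendall": (tkendall, True),
        "wkappa": (wkappa, True),
    }

    def get_metric_names():
        """Return the names of all registered metrics."""
        return list(_METRICS.keys())

    def greater_is_better(metric_name):
        """Return True if larger values of the metric indicate better performance."""
        if metric_name not in _METRICS:
            raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
        return _METRICS[metric_name][1]

    def load_metric_as_scorer(metric_name):
        """Wrap a registered metric in a sklearn scorer with the correct sign."""
        if not isinstance(metric_name, str):
            raise TypeError("metric_name must be a string.")
        if metric_name not in _METRICS:
            raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
        metric, gib = _METRICS[metric_name]
        return make_scorer(metric, greater_is_better=gib)

    def compute_metric(metric_name, y_true, y_pred):
        """Compute the named metric for the given true and predicted labels."""
        if metric_name not in _METRICS:
            raise KeyError(f"Unrecognized metric name: '{metric_name}'.")
        return _METRICS[metric_name][0](y_true, y_pred)

Returning a make_scorer result keeps the _score_func and _sign attributes that the tests inspect, so the returned object can also be passed directly to sklearn utilities such as cross_val_score.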
