Skip to content

Commit 3013880

Browse files
TST: Add unit tests for scalers
1 parent 3b756df commit 3013880

File tree

2 files changed

+174
-0
lines changed

2 files changed

+174
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Tests for the preprocessing module."""
2+
3+
__all__ = []
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""Tests for the scaling functions in the preprocessing module."""
2+
3+
import numpy as np
4+
import numpy.testing as npt
5+
import pytest
6+
7+
from orca_python.preprocessing import apply_scaling, minmax_scale, standardize
8+
from orca_python.preprocessing.scalers import _validate_and_align
9+
from orca_python.testing import TEST_RANDOM_STATE
10+
11+
12+
@pytest.fixture
13+
def X_train():
14+
"""Create synthetic training data for testing."""
15+
return np.random.RandomState(TEST_RANDOM_STATE).randn(100, 5)
16+
17+
18+
@pytest.fixture
19+
def X_test():
20+
"""Create synthetic test data for testing."""
21+
return np.random.RandomState(TEST_RANDOM_STATE).randn(50, 5)
22+
23+
24+
def test_validate_and_align_valid_inputs(X_train, X_test):
25+
"""Test _validate_and_align with valid matching inputs."""
26+
X_train_valid, X_test_valid = _validate_and_align(X_train, X_test)
27+
28+
assert X_train_valid.shape == X_train.shape
29+
assert X_test_valid.shape == X_test.shape
30+
npt.assert_array_equal(X_train_valid, X_train)
31+
npt.assert_array_equal(X_test_valid, X_test)
32+
33+
34+
def test_validate_and_align_none_test(X_train):
35+
"""Test _validate_and_align with None test data."""
36+
X_train_valid, X_test_valid = _validate_and_align(X_train, None)
37+
38+
assert X_train_valid.shape == X_train.shape
39+
assert X_test_valid is None
40+
npt.assert_array_equal(X_train_valid, X_train)
41+
42+
43+
def test_validate_and_align_mismatched_features(X_train, X_test):
44+
"""Test _validate_and_align raises error for mismatched feature counts."""
45+
X_invalid = X_test[:, :-1]
46+
47+
error_msg = "X_test has 4 features but X_train has 5."
48+
with pytest.raises(ValueError, match=error_msg):
49+
_validate_and_align(X_train, X_invalid)
50+
51+
52+
def test_validate_and_align_invalid_input():
53+
"""Test _validate_and_align raises error for invalid input types."""
54+
with pytest.raises((ValueError, TypeError)):
55+
_validate_and_align("invalid", None)
56+
57+
58+
def test_minmax_scale_data(X_train, X_test):
59+
"""Test that minmax_scale function correctly scales input data to [0,1] range."""
60+
X_train_scaled, X_test_scaled = minmax_scale(X_train, X_test)
61+
62+
assert np.all(X_train_scaled >= 0) and np.all(X_train_scaled <= 1)
63+
assert np.all(X_test_scaled >= 0) and np.all(X_test_scaled <= 1)
64+
65+
66+
def test_minmax_scale_return_transformer(X_train, X_test):
67+
"""Test that minmax_scale returns transformer when requested."""
68+
_, expected_X_test, scaler = minmax_scale(X_train, X_test, return_transformer=True)
69+
70+
X_test_scaled = scaler.transform(X_test)
71+
npt.assert_array_almost_equal(X_test_scaled, expected_X_test)
72+
73+
74+
def test_standardize_data(X_train, X_test):
75+
"""Test that standardize function correctly produces output with zero mean
76+
and unit variance."""
77+
X_train_scaled, _ = standardize(X_train, X_test)
78+
79+
npt.assert_almost_equal(np.mean(X_train_scaled), 0, decimal=6)
80+
npt.assert_almost_equal(np.std(X_train_scaled), 1, decimal=6)
81+
82+
83+
def test_standardize_return_transformer(X_train, X_test):
84+
"""Test that standardize returns transformer when requested."""
85+
_, expected_X_test, scaler = standardize(X_train, X_test, return_transformer=True)
86+
87+
X_test_scaled = scaler.transform(X_test)
88+
npt.assert_array_almost_equal(X_test_scaled, expected_X_test)
89+
90+
91+
@pytest.mark.parametrize(
92+
"method, scaling_func", [("norm", minmax_scale), ("std", standardize)]
93+
)
94+
def test_apply_scaling_correctly(X_train, X_test, method, scaling_func):
95+
"""Test that different preprocessing methods work as expected."""
96+
expected_X_train, expected_X_test = scaling_func(X_train, X_test)
97+
X_train_scaled, X_test_scaled = apply_scaling(X_train, X_test, method)
98+
99+
npt.assert_array_almost_equal(X_train_scaled, expected_X_train)
100+
npt.assert_array_almost_equal(X_test_scaled, expected_X_test)
101+
102+
103+
def test_apply_scaling_none_method(X_train, X_test):
104+
"""Test that scaling function handles None input correctly."""
105+
post_X_train, post_X_test = apply_scaling(X_train, X_test, None)
106+
107+
npt.assert_array_equal(post_X_train, X_train)
108+
npt.assert_array_equal(post_X_test, X_test)
109+
110+
111+
def test_apply_scaling_return_transformer(X_train, X_test):
112+
"""Test that the transformer returned by apply_scaling works as expected."""
113+
_, _, scaler = apply_scaling(X_train, X_test, "norm", return_transformer=True)
114+
X_test_scaled = scaler.transform(X_test)
115+
_, expected_X_test = minmax_scale(X_train, X_test)
116+
npt.assert_array_almost_equal(X_test_scaled, expected_X_test)
117+
118+
_, _, scaler = apply_scaling(X_train, X_test, "std", return_transformer=True)
119+
X_test_scaled = scaler.transform(X_test)
120+
_, expected_X_test = standardize(X_train, X_test)
121+
npt.assert_array_almost_equal(X_test_scaled, expected_X_test)
122+
123+
_, _, scaler = apply_scaling(X_train, X_test, None, return_transformer=True)
124+
assert scaler is None
125+
126+
127+
def test_apply_scaling_case_insensitive(X_train, X_test):
128+
"""Test that apply_scaling handles different case variations."""
129+
X_train_lower, X_test_lower = apply_scaling(X_train, X_test, "norm")
130+
X_train_upper, X_test_upper = apply_scaling(X_train, X_test, "NORM")
131+
132+
npt.assert_array_equal(X_train_lower, X_train_upper)
133+
npt.assert_array_equal(X_test_lower, X_test_upper)
134+
135+
136+
def test_apply_scaling_invalid_method_type(X_train, X_test):
137+
"""Test that an invalid scaling method type raises a ValueError."""
138+
error_msg = "Scaling method must be a string or None."
139+
with pytest.raises(ValueError, match=error_msg):
140+
apply_scaling(X_train, X_test, 123)
141+
142+
143+
def test_apply_scaling_unknown_method(X_train, X_test):
144+
"""Test that an unknown scaling method raises a ValueError."""
145+
error_msg = "Unknown scaling method 'invalid'. Valid options: 'norm', 'std', None."
146+
with pytest.raises(ValueError, match=error_msg):
147+
apply_scaling(X_train, X_test, "invalid")
148+
149+
150+
def test_apply_scaling_inconsistent_features(X_train, X_test):
151+
"""Test that scaling with inconsistent feature dimensions raises ValueError."""
152+
X_invalid = X_test[:, :-1]
153+
154+
with pytest.raises(ValueError):
155+
apply_scaling(X_train, X_invalid, "norm")
156+
157+
158+
def test_minmax_scale_inconsistent_features(X_train, X_test):
159+
"""Test that minmax_scale raises ValueError for mismatched features."""
160+
X_invalid = X_test[:, :-1]
161+
162+
with pytest.raises(ValueError):
163+
minmax_scale(X_train, X_invalid)
164+
165+
166+
def test_standardize_inconsistent_features(X_train, X_test):
167+
"""Test that standardize raises ValueError for mismatched features."""
168+
X_invalid = X_test[:, :-1]
169+
170+
with pytest.raises(ValueError):
171+
standardize(X_train, X_invalid)

0 commit comments

Comments
 (0)