Skip to content

Commit a16693a

Browse files
authored
Add tests for CUDA device (#9)
1 parent a1acbe4 commit a16693a

File tree

2 files changed

+55
-15
lines changed

2 files changed

+55
-15
lines changed

tests/test_analytical_crps.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,57 @@
22
import torch
33
from torch.distributions import Normal, StudentT
44

5+
from tests.conftest import needs_cuda
56
from torch_crps import crps_analytical_naive_integral, crps_analytical_normal, crps_analytical_studentt
67

78

8-
def test_crps_analytical_normal_batched_smoke():
9+
@pytest.mark.parametrize(
10+
"use_cuda",
11+
[
12+
pytest.param(False, id="cpu"),
13+
pytest.param(True, marks=needs_cuda, id="cuda"),
14+
],
15+
)
16+
def test_crps_analytical_normal_batched_smoke(use_cuda: bool):
917
"""Test that analytical solution works with batched Normal distributions."""
1018
torch.manual_seed(0)
1119

1220
# Define a batch of 2 independent univariate Normal distributions.
13-
mu = torch.tensor([[0.0, 1.0], [2.0, 3.0], [-2.0, -3.0]])
14-
sigma = torch.tensor([[1.0, 0.5], [1.5, 2.0], [0.01, 0.01]])
21+
mu = torch.tensor([[0.0, 1.0], [2.0, 3.0], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu")
22+
sigma = torch.tensor([[1.0, 0.5], [1.5, 2.0], [0.01, 0.01]], device="cuda" if use_cuda else "cpu")
1523
normal_dist = torch.distributions.Normal(loc=mu, scale=sigma)
1624

1725
# Define observed values for each distribution in the batch.
18-
y = torch.tensor([[0.5, 1.5], [2.5, 3.5], [-2.0, -3.0]])
26+
y = torch.tensor([[0.5, 1.5], [2.5, 3.5], [-2.0, -3.0]], device="cuda" if use_cuda else "cpu")
1927

2028
# Compute CRPS using the analytical method.
2129
crps_analytical = crps_analytical_normal(normal_dist, y)
2230

2331
# Simple sanity check: CRPS should be non-negative.
24-
assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative."
2532
assert crps_analytical.shape == y.shape, "CRPS output shape should match input shape."
33+
assert crps_analytical.dtype in [torch.float32, torch.float64], "CRPS output dtype should be float."
34+
assert crps_analytical.device == y.device, "CRPS output device should match input device."
35+
assert torch.all(crps_analytical >= 0), "CRPS values should be non-negative."
2636

2737

28-
def test_crps_analytical_naive_integral_vs_analytical_normal():
38+
@pytest.mark.parametrize(
39+
"use_cuda",
40+
[
41+
pytest.param(False, id="cpu"),
42+
pytest.param(True, marks=needs_cuda, id="cuda"),
43+
],
44+
)
45+
def test_crps_analytical_naive_integral_vs_analytical_normal(use_cuda: bool):
2946
"""Test that naive integral method matches the analytical solution for Normal distributions."""
3047
torch.manual_seed(0)
3148

3249
# Define 4 independent univariate Normal distributions.
33-
mu = torch.tensor([0.0, 0.0, 3.0, -7.0])
34-
sigma = torch.tensor([1.0, 0.01, 1.5, 0.5])
50+
mu = torch.tensor([0.0, 0.0, 3.0, -7.0], device="cuda" if use_cuda else "cpu")
51+
sigma = torch.tensor([1.0, 0.01, 1.5, 0.5], device="cuda" if use_cuda else "cpu")
3552
normal_dist = torch.distributions.Normal(loc=mu, scale=sigma)
3653

3754
# Define observed values, one for each distribution.
38-
y = torch.tensor([0.5, 0.0, 4.5, -6.0])
55+
y = torch.tensor([0.5, 0.0, 4.5, -6.0], device="cuda" if use_cuda else "cpu")
3956

4057
# Compute CRPS.
4158
crps_naive = crps_analytical_naive_integral(normal_dist, y, x_min=-10, x_max=10, x_steps=10001)
@@ -50,20 +67,28 @@ def test_crps_analytical_naive_integral_vs_analytical_normal():
5067
assert torch.allclose(crps_naive, crps_analytical, atol=1e-3, rtol=5e-4), (
5168
f"CRPS values do not match: naive={crps_naive}, analytical={crps_analytical}"
5269
)
70+
assert crps_naive.device == crps_analytical.device == y.device, "CRPS output device should match input device."
5371

5472

55-
def test_crps_analytical_naive_integral_vs_analytical_studentt():
73+
@pytest.mark.parametrize(
74+
"use_cuda",
75+
[
76+
pytest.param(False, id="cpu"),
77+
pytest.param(True, marks=needs_cuda, id="cuda"),
78+
],
79+
)
80+
def test_crps_analytical_naive_integral_vs_analytical_studentt(use_cuda: bool):
5681
"""Test that naive integral method matches the analytical solution for StudentT distributions."""
5782
torch.manual_seed(0)
5883

5984
# Define 4 independent univariate StudentT distributions.
60-
df = torch.tensor([100.0, 3.0, 5.0, 5.0])
61-
mu = torch.tensor([0.0, 0.0, 3.0, -7.0])
62-
sigma = torch.tensor([1.0, 0.01, 1.5, 0.5])
85+
df = torch.tensor([100.0, 3.0, 5.0, 5.0], device="cuda" if use_cuda else "cpu")
86+
mu = torch.tensor([0.0, 0.0, 3.0, -7.0], device="cuda" if use_cuda else "cpu")
87+
sigma = torch.tensor([1.0, 0.01, 1.5, 0.5], device="cuda" if use_cuda else "cpu")
6388
studentt_dist = torch.distributions.StudentT(df=df, loc=mu, scale=sigma)
6489

6590
# Define observed values, one for each distribution.
66-
y = torch.tensor([0.5, 0.0, 4.5, -6.0])
91+
y = torch.tensor([0.5, 0.0, 4.5, -6.0], device="cuda" if use_cuda else "cpu")
6792

6893
# Compute CRPS.
6994
crps_naive = crps_analytical_naive_integral(studentt_dist, y, x_min=-10, x_max=10, x_steps=10001)
@@ -78,6 +103,7 @@ def test_crps_analytical_naive_integral_vs_analytical_studentt():
78103
assert torch.allclose(crps_naive, crps_analytical, atol=1e-3, rtol=5e-4), (
79104
f"CRPS values do not match: naive={crps_naive}, analytical={crps_analytical}"
80105
)
106+
assert crps_naive.device == crps_analytical.device == y.device, "CRPS output device should match input device."
81107

82108

83109
@pytest.mark.parametrize(

tests/test_ensemble_crps.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from _pytest.fixtures import FixtureRequest
66

7+
from tests.conftest import needs_cuda
78
from torch_crps import crps_ensemble, crps_ensemble_naive
89

910

@@ -14,16 +15,29 @@
1415
)
1516
@pytest.mark.parametrize("crps_fcn", [crps_ensemble_naive, crps_ensemble], ids=["naive", "default"])
1617
@pytest.mark.parametrize("biased", [True, False], ids=["biased", "unbiased"])
17-
def test_crps_ensemble_smoke(test_case_fixture_name: str, crps_fcn: Callable, biased: bool, request: FixtureRequest):
18+
@pytest.mark.parametrize(
19+
"use_cuda",
20+
[
21+
pytest.param(False, id="cpu"),
22+
pytest.param(True, marks=needs_cuda, id="cuda"),
23+
],
24+
)
25+
def test_crps_ensemble_smoke(
26+
test_case_fixture_name: str, crps_fcn: Callable, biased: bool, use_cuda: bool, request: FixtureRequest
27+
):
1828
"""Test that naive ensemble method yield."""
1929
test_case_fixture: dict = request.getfixturevalue(test_case_fixture_name)
2030
x, y, expected_shape = test_case_fixture["x"], test_case_fixture["y"], test_case_fixture["expected_shape"]
31+
if use_cuda:
32+
x, y = x.cuda(), y.cuda()
2133

2234
crps = crps_fcn(x, y, biased)
2335

2436
assert isinstance(crps, torch.Tensor)
2537
assert crps.shape == expected_shape, "The output shape is incorrect!"
2638
assert crps.dtype in [torch.float32, torch.float64], "The output dtype is not float!"
39+
assert crps.device == x.device, "The output device does not match the input device!"
40+
assert torch.all(crps >= 0), "CRPS values should be non-negative!"
2741

2842

2943
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)