22import torch
33from torch .distributions import Normal , StudentT
44
5+ from tests .conftest import needs_cuda
56from torch_crps import crps_analytical_naive_integral , crps_analytical_normal , crps_analytical_studentt
67
78
8- def test_crps_analytical_normal_batched_smoke ():
9+ @pytest .mark .parametrize (
10+ "use_cuda" ,
11+ [
12+ pytest .param (False , id = "cpu" ),
13+ pytest .param (True , marks = needs_cuda , id = "cuda" ),
14+ ],
15+ )
16+ def test_crps_analytical_normal_batched_smoke (use_cuda : bool ):
917 """Test that analytical solution works with batched Normal distributions."""
1018 torch .manual_seed (0 )
1119
1220 # Define a batch of 2 independent univariate Normal distributions.
13- mu = torch .tensor ([[0.0 , 1.0 ], [2.0 , 3.0 ], [- 2.0 , - 3.0 ]])
14- sigma = torch .tensor ([[1.0 , 0.5 ], [1.5 , 2.0 ], [0.01 , 0.01 ]])
21+ mu = torch .tensor ([[0.0 , 1.0 ], [2.0 , 3.0 ], [- 2.0 , - 3.0 ]], device = "cuda" if use_cuda else "cpu" )
22+ sigma = torch .tensor ([[1.0 , 0.5 ], [1.5 , 2.0 ], [0.01 , 0.01 ]], device = "cuda" if use_cuda else "cpu" )
1523 normal_dist = torch .distributions .Normal (loc = mu , scale = sigma )
1624
1725 # Define observed values for each distribution in the batch.
18- y = torch .tensor ([[0.5 , 1.5 ], [2.5 , 3.5 ], [- 2.0 , - 3.0 ]])
26+ y = torch .tensor ([[0.5 , 1.5 ], [2.5 , 3.5 ], [- 2.0 , - 3.0 ]], device = "cuda" if use_cuda else "cpu" )
1927
2028 # Compute CRPS using the analytical method.
2129 crps_analytical = crps_analytical_normal (normal_dist , y )
2230
2331 # Simple sanity check: CRPS should be non-negative.
24- assert torch .all (crps_analytical >= 0 ), "CRPS values should be non-negative."
2532 assert crps_analytical .shape == y .shape , "CRPS output shape should match input shape."
33+ assert crps_analytical .dtype in [torch .float32 , torch .float64 ], "CRPS output dtype should be float."
34+ assert crps_analytical .device == y .device , "CRPS output device should match input device."
35+ assert torch .all (crps_analytical >= 0 ), "CRPS values should be non-negative."
2636
2737
28- def test_crps_analytical_naive_integral_vs_analytical_normal ():
38+ @pytest .mark .parametrize (
39+ "use_cuda" ,
40+ [
41+ pytest .param (False , id = "cpu" ),
42+ pytest .param (True , marks = needs_cuda , id = "cuda" ),
43+ ],
44+ )
45+ def test_crps_analytical_naive_integral_vs_analytical_normal (use_cuda : bool ):
2946 """Test that naive integral method matches the analytical solution for Normal distributions."""
3047 torch .manual_seed (0 )
3148
3249 # Define 4 independent univariate Normal distributions.
33- mu = torch .tensor ([0.0 , 0.0 , 3.0 , - 7.0 ])
34- sigma = torch .tensor ([1.0 , 0.01 , 1.5 , 0.5 ])
50+ mu = torch .tensor ([0.0 , 0.0 , 3.0 , - 7.0 ], device = "cuda" if use_cuda else "cpu" )
51+ sigma = torch .tensor ([1.0 , 0.01 , 1.5 , 0.5 ], device = "cuda" if use_cuda else "cpu" )
3552 normal_dist = torch .distributions .Normal (loc = mu , scale = sigma )
3653
3754 # Define observed values, one for each distribution.
38- y = torch .tensor ([0.5 , 0.0 , 4.5 , - 6.0 ])
55+ y = torch .tensor ([0.5 , 0.0 , 4.5 , - 6.0 ], device = "cuda" if use_cuda else "cpu" )
3956
4057 # Compute CRPS.
4158 crps_naive = crps_analytical_naive_integral (normal_dist , y , x_min = - 10 , x_max = 10 , x_steps = 10001 )
@@ -50,20 +67,28 @@ def test_crps_analytical_naive_integral_vs_analytical_normal():
5067 assert torch .allclose (crps_naive , crps_analytical , atol = 1e-3 , rtol = 5e-4 ), (
5168 f"CRPS values do not match: naive={ crps_naive } , analytical={ crps_analytical } "
5269 )
70+ assert crps_naive .device == crps_analytical .device == y .device , "CRPS output device should match input device."
5371
5472
55- def test_crps_analytical_naive_integral_vs_analytical_studentt ():
73+ @pytest .mark .parametrize (
74+ "use_cuda" ,
75+ [
76+ pytest .param (False , id = "cpu" ),
77+ pytest .param (True , marks = needs_cuda , id = "cuda" ),
78+ ],
79+ )
80+ def test_crps_analytical_naive_integral_vs_analytical_studentt (use_cuda : bool ):
5681 """Test that naive integral method matches the analytical solution for StudentT distributions."""
5782 torch .manual_seed (0 )
5883
5984 # Define 4 independent univariate StudentT distributions.
60- df = torch .tensor ([100.0 , 3.0 , 5.0 , 5.0 ])
61- mu = torch .tensor ([0.0 , 0.0 , 3.0 , - 7.0 ])
62- sigma = torch .tensor ([1.0 , 0.01 , 1.5 , 0.5 ])
85+ df = torch .tensor ([100.0 , 3.0 , 5.0 , 5.0 ], device = "cuda" if use_cuda else "cpu" )
86+ mu = torch .tensor ([0.0 , 0.0 , 3.0 , - 7.0 ], device = "cuda" if use_cuda else "cpu" )
87+ sigma = torch .tensor ([1.0 , 0.01 , 1.5 , 0.5 ], device = "cuda" if use_cuda else "cpu" )
6388 studentt_dist = torch .distributions .StudentT (df = df , loc = mu , scale = sigma )
6489
6590 # Define observed values, one for each distribution.
66- y = torch .tensor ([0.5 , 0.0 , 4.5 , - 6.0 ])
91+ y = torch .tensor ([0.5 , 0.0 , 4.5 , - 6.0 ], device = "cuda" if use_cuda else "cpu" )
6792
6893 # Compute CRPS.
6994 crps_naive = crps_analytical_naive_integral (studentt_dist , y , x_min = - 10 , x_max = 10 , x_steps = 10001 )
@@ -78,6 +103,7 @@ def test_crps_analytical_naive_integral_vs_analytical_studentt():
78103 assert torch .allclose (crps_naive , crps_analytical , atol = 1e-3 , rtol = 5e-4 ), (
79104 f"CRPS values do not match: naive={ crps_naive } , analytical={ crps_analytical } "
80105 )
106+ assert crps_naive .device == crps_analytical .device == y .device , "CRPS output device should match input device."
81107
82108
83109@pytest .mark .parametrize (
0 commit comments