Skip to content

Commit ae20415

Browse files
committed
add back params
1 parent 5bd880c commit ae20415

File tree

4 files changed

+124
-107
lines changed

4 files changed

+124
-107
lines changed

tests/test_output_plots.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,37 @@
2020
base_tpi = utils.safe_read_pickle(
2121
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_baseline.pkl")
2222
)
23-
if sys.version_info[1] == 11:
24-
base_params = utils.safe_read_pickle(
25-
os.path.join(
26-
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
27-
)
28-
)
29-
elif sys.version_info[1] == 12:
30-
base_params = utils.safe_read_pickle(
31-
os.path.join(
32-
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
33-
)
34-
)
35-
elif sys.version_info[1] == 13:
36-
base_params = utils.safe_read_pickle(
37-
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
38-
)
39-
else:
40-
# Raise assertion error
41-
assert False, "Unsupported Python version"
23+
4224
reform_ss = utils.safe_read_pickle(
4325
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
4426
)
4527
reform_tpi = utils.safe_read_pickle(
4628
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_reform.pkl")
4729
)
48-
if sys.version_info[1] == 11:
49-
reform_params = utils.safe_read_pickle(
50-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v311.pkl")
51-
)
52-
elif sys.version_info[1] == 12:
53-
reform_params = utils.safe_read_pickle(
54-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v312.pkl")
55-
)
56-
elif sys.version_info[1] == 13:
57-
reform_params = utils.safe_read_pickle(
58-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform.pkl")
59-
)
60-
else:
61-
# Raise assertion error
62-
assert False, "Unsupported Python version"
30+
base_params = Specifications()
31+
base_params.update_specifications(
32+
{
33+
"M": 3,
34+
"gamma": [0.5, 0.5, 0.5],
35+
"gamma_g": [0.0, 0.0, 0.0],
36+
"epsilon": [0.5, 0.5, 0.5],
37+
"I": 3,
38+
"alpha_c": [0.3, 0.4, 0.3],
39+
"io_matrix": np.eye(3),
40+
}
41+
)
42+
reform_params = Specifications()
43+
reform_params.update_specifications(
44+
{
45+
"M": 3,
46+
"gamma": [0.5, 0.5, 0.5],
47+
"gamma_g": [0.0, 0.0, 0.0],
48+
"epsilon": [0.5, 0.5, 0.5],
49+
"I": 3,
50+
"alpha_c": [0.3, 0.4, 0.3],
51+
"io_matrix": np.eye(3),
52+
}
53+
)
6354
reform_taxfunctions = utils.safe_read_pickle(
6455
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_reform.pkl")
6556
)

tests/test_output_tables.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,37 @@
1919
base_tpi = utils.safe_read_pickle(
2020
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_baseline.pkl")
2121
)
22-
if sys.version_info[1] == 11:
23-
base_params = utils.safe_read_pickle(
24-
os.path.join(
25-
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
26-
)
27-
)
28-
elif sys.version_info[1] == 12:
29-
base_params = utils.safe_read_pickle(
30-
os.path.join(
31-
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
32-
)
33-
)
34-
elif sys.version_info[1] == 13:
35-
base_params = utils.safe_read_pickle(
36-
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
37-
)
38-
else:
39-
# Raise assertion error
40-
assert False, "Unsupported Python version"
22+
4123
reform_ss = utils.safe_read_pickle(
4224
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
4325
)
4426
reform_tpi = utils.safe_read_pickle(
4527
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_reform.pkl")
4628
)
47-
if sys.version_info[1] == 11:
48-
reform_params = utils.safe_read_pickle(
49-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v311.pkl")
50-
)
51-
elif sys.version_info[1] == 12:
52-
reform_params = utils.safe_read_pickle(
53-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform_v312.pkl")
54-
)
55-
elif sys.version_info[1] == 13:
56-
reform_params = utils.safe_read_pickle(
57-
os.path.join(CUR_PATH, "test_io_data", "model_params_reform.pkl")
58-
)
59-
else:
60-
# Raise assertion error
61-
assert False, "Unsupported Python version"
29+
base_params = Specifications()
30+
base_params.update_specifications(
31+
{
32+
"M": 3,
33+
"gamma": [0.5, 0.5, 0.5],
34+
"gamma_g": [0.0, 0.0, 0.0],
35+
"epsilon": [0.5, 0.5, 0.5],
36+
"I": 3,
37+
"alpha_c": [0.3, 0.4, 0.3],
38+
"io_matrix": np.eye(3),
39+
}
40+
)
41+
reform_params = Specifications()
42+
reform_params.update_specifications(
43+
{
44+
"M": 3,
45+
"gamma": [0.5, 0.5, 0.5],
46+
"gamma_g": [0.0, 0.0, 0.0],
47+
"epsilon": [0.5, 0.5, 0.5],
48+
"I": 3,
49+
"alpha_c": [0.3, 0.4, 0.3],
50+
"io_matrix": np.eye(3),
51+
}
52+
)
6253
# add investment tax credit parameter that not in cached parameters
6354
base_params.inv_tax_credit = np.zeros(
6455
(base_params.T + base_params.S, base_params.M)

tests/test_parameter_plots.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,43 @@
1414

1515
# Load in test results and parameters
1616
CUR_PATH = os.path.abspath(os.path.dirname(__file__))
17-
if sys.version_info[1] == 11:
18-
base_params = utils.safe_read_pickle(
19-
os.path.join(
20-
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
21-
)
22-
)
23-
elif sys.version_info[1] == 12:
24-
base_params = utils.safe_read_pickle(
25-
os.path.join(
26-
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
27-
)
28-
)
29-
elif sys.version_info[1] == 13:
30-
base_params = utils.safe_read_pickle(
31-
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
32-
)
33-
else:
34-
# Raise assertion error
35-
assert False, "Unsupported Python version"
17+
base_ss = utils.safe_read_pickle(
18+
os.path.join(CUR_PATH, "test_io_data", "SS_vars_baseline.pkl")
19+
)
20+
base_tpi = utils.safe_read_pickle(
21+
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_baseline.pkl")
22+
)
23+
24+
reform_ss = utils.safe_read_pickle(
25+
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
26+
)
27+
reform_tpi = utils.safe_read_pickle(
28+
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_reform.pkl")
29+
)
30+
base_params = Specifications()
31+
base_params.update_specifications(
32+
{
33+
"M": 3,
34+
"gamma": [0.5, 0.5, 0.5],
35+
"gamma_g": [0.0, 0.0, 0.0],
36+
"epsilon": [0.5, 0.5, 0.5],
37+
"I": 3,
38+
"alpha_c": [0.3, 0.4, 0.3],
39+
"io_matrix": np.eye(3),
40+
}
41+
)
42+
reform_params = Specifications()
43+
reform_params.update_specifications(
44+
{
45+
"M": 3,
46+
"gamma": [0.5, 0.5, 0.5],
47+
"gamma_g": [0.0, 0.0, 0.0],
48+
"epsilon": [0.5, 0.5, 0.5],
49+
"I": 3,
50+
"alpha_c": [0.3, 0.4, 0.3],
51+
"io_matrix": np.eye(3),
52+
}
53+
)
3654

3755
base_taxfunctions = utils.safe_read_pickle(
3856
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_baseline.pkl")

tests/test_parameter_tables.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,53 @@
44

55
import pytest
66
import os
7-
import sys
7+
import numpy as np
88
from ogcore import utils, parameter_tables
99
from ogcore.parameters import Specifications
1010

1111

1212
# Load in test results and parameters
1313
CUR_PATH = os.path.abspath(os.path.dirname(__file__))
1414

15-
if sys.version_info[1] == 11:
16-
base_params = utils.safe_read_pickle(
17-
os.path.join(
18-
CUR_PATH, "test_io_data", "model_params_baseline_v311.pkl"
19-
)
20-
)
21-
elif sys.version_info[1] == 12:
22-
base_params = utils.safe_read_pickle(
23-
os.path.join(
24-
CUR_PATH, "test_io_data", "model_params_baseline_v312.pkl"
25-
)
26-
)
27-
elif sys.version_info[1] == 13:
28-
base_params = utils.safe_read_pickle(
29-
os.path.join(CUR_PATH, "test_io_data", "model_params_baseline.pkl")
30-
)
31-
else:
32-
# Raise assertion error
33-
assert False, "Unsupported Python version"
15+
base_ss = utils.safe_read_pickle(
16+
os.path.join(CUR_PATH, "test_io_data", "SS_vars_baseline.pkl")
17+
)
18+
base_tpi = utils.safe_read_pickle(
19+
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_baseline.pkl")
20+
)
3421
base_taxfunctions = utils.safe_read_pickle(
3522
os.path.join(CUR_PATH, "test_io_data", "TxFuncEst_baseline.pkl")
3623
)
24+
reform_ss = utils.safe_read_pickle(
25+
os.path.join(CUR_PATH, "test_io_data", "SS_vars_reform.pkl")
26+
)
27+
reform_tpi = utils.safe_read_pickle(
28+
os.path.join(CUR_PATH, "test_io_data", "TPI_vars_reform.pkl")
29+
)
30+
base_params = Specifications()
31+
base_params.update_specifications(
32+
{
33+
"M": 3,
34+
"gamma": [0.5, 0.5, 0.5],
35+
"gamma_g": [0.0, 0.0, 0.0],
36+
"epsilon": [0.5, 0.5, 0.5],
37+
"I": 3,
38+
"alpha_c": [0.3, 0.4, 0.3],
39+
"io_matrix": np.eye(3),
40+
}
41+
)
42+
reform_params = Specifications()
43+
reform_params.update_specifications(
44+
{
45+
"M": 3,
46+
"gamma": [0.5, 0.5, 0.5],
47+
"gamma_g": [0.0, 0.0, 0.0],
48+
"epsilon": [0.5, 0.5, 0.5],
49+
"I": 3,
50+
"alpha_c": [0.3, 0.4, 0.3],
51+
"io_matrix": np.eye(3),
52+
}
53+
)
3754

3855

3956
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)