Skip to content

Commit b491d93

Browse files
implementing PR feedback
1 parent 3fa4f54 commit b491d93

File tree

6 files changed

+191
-83
lines changed

6 files changed

+191
-83
lines changed

docs/source/computational_implementation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ The Temoa model code is organized into clear, purpose-driven packages:
389389
* ``monte_carlo`` - :doc:`monte_carlo` (Uncertainty quantification)
390390
* ``myopic`` - Sequential decision making with limited foresight
391391
* ``single_vector_mga`` - Focused MGA on specific variables ([!] untested in v4.0)
392-
* ``stochastics`` - Stochastic programming capabilities
392+
* ``stochastics`` - :doc:`stochastics` (Stochastic programming capabilities)
393393

394394
* ``temoa._internal`` - Internal utilities (not part of public API)
395395

docs/source/stochastics.rst

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,33 @@ The stochastic configuration file defines the scenarios, their probabilities, an
6060
# Define perturbations for a specific scenario
6161
[[perturbations]]
6262
scenario = "low_cost"
63-
variable = "cost_variable"
64-
tech = "IMPHCO1"
65-
# The value is a multiplier applied to the base value in the database
63+
table = "cost_variable"
64+
# Filter specifies which rows in the table to perturb
65+
filter = { tech = "IMPHCO1" }
66+
# Action can be "multiply", "add", or "set" (defaults to "set")
67+
action = "multiply"
6668
value = 0.5
6769
6870
[[perturbations]]
6971
scenario = "high_cost"
70-
variable = "cost_variable"
71-
tech = "IMPHCO1"
72+
table = "cost_variable"
73+
filter = { tech = "IMPHCO1" }
74+
action = "multiply"
7275
value = 1.5
7376
7477
Perturbation Options
7578
^^^^^^^^^^^^^^^^^^^^
7679

77-
Currently, the following perturbations are supported:
80+
Currently, the following fields are required for each perturbation:
7881

79-
* **variable**: The Temoa parameter to perturb. Currently supports:
80-
* ``cost_variable``
81-
* ``cost_invest``
82-
* ``cost_fixed``
83-
* ``demand``
84-
* **tech**: The technology to which the perturbation applies (required for cost perturbations).
85-
* **commodity**: The commodity to which the perturbation applies (required for demand perturbations).
86-
* **value**: A multiplier applied to the base value found in the input database. For example, a value of ``1.5`` increases the base value by 50%.
82+
* **scenario**: The name of the scenario to which this perturbation applies.
83+
* **table**: The Temoa parameter (database table) to perturb (e.g., ``cost_variable``, ``demand``, ``capacity_factor_process``).
84+
* **filter**: A dictionary of column-value pairs used to identify specific rows. Since the extension uses the dynamic manifest from ``HybridLoader``, any column belonging to the table's index can be used for filtering.
85+
* **action**: The operation to perform. Supported values:
86+
* ``multiply``: Multiply the base value by ``value``.
87+
* ``add``: Add ``value`` to the base value.
88+
* ``set``: Replace the base value with ``value``.
89+
* **value**: The numeric value used in the perturbation action.
8790

8891
How it Works
8992
------------
@@ -101,6 +104,4 @@ Limitations
101104
-----------
102105

103106
* **Two-Stage Only**: While ``mpi-sppy`` supports multi-stage stochastic programming, the current Temoa integration is tailored for two-stage problems where the first time period constitutes the first stage.
104-
* **Cost/Demand Multipliers**: Perturbations are currently implemented as multipliers on base values. Absolute value overrides are not yet supported.
105-
* **Solver Support**: The extension has been primarily tested with the ``appsi_highs`` and ``cbc`` solvers via Pyomo.
106107
* **Result Persistence**: Currently, only the expected objective value and summary logs are produced. Detailed per-scenario result persistence to the database is under development.

temoa/extensions/stochastics/scenario_creator.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,50 @@
1-
import sqlite3
21
import logging
2+
import sqlite3
3+
from typing import TYPE_CHECKING, Any
4+
35
import pyomo.environ as pyo
46
from mpisppy.utils.sputils import attach_root_node
57

6-
from temoa.core.config import TemoaConfig
7-
from temoa.extensions.stochastics.stochastic_config import StochasticConfig
8-
from temoa.data_io.hybrid_loader import HybridLoader
98
from temoa._internal.run_actions import build_instance
109
from temoa.components.costs import period_cost_rule
10+
from temoa.data_io.hybrid_loader import HybridLoader
11+
12+
if TYPE_CHECKING:
13+
from temoa.core.config import TemoaConfig
14+
from temoa.extensions.stochastics.stochastic_config import StochasticConfig
1115

1216
logger = logging.getLogger(__name__)
1317

14-
def scenario_creator(scenario_name, **kwargs):
18+
19+
def scenario_creator(scenario_name: str, **kwargs: Any) -> pyo.ConcreteModel:
1520
"""
1621
Creator for mpi-sppy scenarios.
1722
1823
Args:
1924
scenario_name (str): Name of the scenario to create.
2025
**kwargs: Must contain 'temoa_config' and 'stoch_config'.
2126
"""
27+
if 'temoa_config' not in kwargs or 'stoch_config' not in kwargs:
28+
raise ValueError("scenario_creator requires 'temoa_config' and 'stoch_config' in kwargs")
29+
2230
temoa_config: TemoaConfig = kwargs['temoa_config']
2331
stoch_config: StochasticConfig = kwargs['stoch_config']
2432

2533
# 1. Load base data
26-
with sqlite3.connect(temoa_config.input_database) as con:
27-
hybrid_loader = HybridLoader(db_connection=con, config=temoa_config)
28-
data_dict = hybrid_loader.create_data_dict(myopic_index=None)
34+
try:
35+
with sqlite3.connect(temoa_config.input_database) as con:
36+
hybrid_loader = HybridLoader(db_connection=con, config=temoa_config)
37+
data_dict = hybrid_loader.create_data_dict(myopic_index=None)
38+
39+
# Build a map of table -> index columns from the manifest
40+
# For each LoadItem, the index columns are all but the last one (which is the value)
41+
table_index_map = {}
42+
for item in hybrid_loader.manifest:
43+
if item.table not in table_index_map and item.columns:
44+
table_index_map[item.table] = item.columns[:-1]
45+
except Exception as e:
46+
logger.exception('Failed to connect to database %s', temoa_config.input_database)
47+
raise RuntimeError(f'Failed to connect to database {temoa_config.input_database}') from e
2948

3049
# 2. Apply perturbations for this scenario
3150
for p in stoch_config.perturbations:
@@ -34,30 +53,25 @@ def scenario_creator(scenario_name, **kwargs):
3453

3554
target_param = data_dict.get(p.table)
3655
if target_param is None:
37-
logger.warning(f"Table {p.table} not found in data_dict for scenario {scenario_name}")
56+
logger.warning(
57+
'Table %s not found in data_dict for scenario %s', p.table, scenario_name
58+
)
3859
continue
3960

4061
# target_param is {(idx...): value}
4162
# We need to find entries matching p.filter
63+
index_cols = table_index_map.get(p.table)
64+
if index_cols is None:
65+
logger.warning(
66+
'Table %s not found in manifest; cannot map indices for scenario %s',
67+
p.table,
68+
scenario_name,
69+
)
70+
continue
71+
4272
for idx_tuple, current_val in list(target_param.items()):
43-
# Temoa v4 parameter indices:
44-
# cost_variable: (region, period, tech, vintage)
45-
# cost_invest: (region, tech, period)
46-
# cost_fixed: (region, period, tech, vintage)
47-
48-
# Map index tuple to names based on table
49-
# This is a bit brittle, but handles common cases
50-
index_map = {}
51-
if p.table == 'cost_variable':
52-
index_map = {'region': idx_tuple[0], 'period': idx_tuple[1], 'tech': idx_tuple[2], 'vintage': idx_tuple[3]}
53-
elif p.table == 'cost_invest':
54-
index_map = {'region': idx_tuple[0], 'tech': idx_tuple[1], 'period': idx_tuple[2]}
55-
elif p.table == 'cost_fixed':
56-
index_map = {'region': idx_tuple[0], 'period': idx_tuple[1], 'tech': idx_tuple[2], 'vintage': idx_tuple[3]}
57-
else:
58-
# Generic fallback if filter names match index positions?
59-
# For now, only support these three
60-
pass
73+
# Map index tuple to names based on table manifest
74+
index_map = dict(zip(index_cols, idx_tuple, strict=True))
6175

6276
# Check if filter matches
6377
match = True
@@ -79,10 +93,16 @@ def scenario_creator(scenario_name, **kwargs):
7993
instance = build_instance(data_portal, silent=True)
8094

8195
# 4. Attach root node (Stage 1)
82-
periods = sorted(list(instance.time_optimize))
96+
periods = sorted(instance.time_optimize)
8397
first_period = periods[0]
8498

85-
prob = stoch_config.scenarios.get(scenario_name, {}).get('probability', 1.0)
99+
prob = stoch_config.scenarios.get(scenario_name)
100+
if prob is None:
101+
logger.warning(
102+
"Scenario '%s' not found in stochastic config probabilities; defaulting to 1.0",
103+
scenario_name,
104+
)
105+
prob = 1.0
86106
instance._mpisppy_probability = prob
87107

88108
# First stage variables: v_new_capacity[*, *, first_period]
@@ -91,6 +111,15 @@ def scenario_creator(scenario_name, **kwargs):
91111
if p == first_period:
92112
first_stage_vars.append(instance.v_new_capacity[r, t, p])
93113

114+
if not first_stage_vars:
115+
logger.error(
116+
'No first-stage variables (v_new_capacity for period %s) found for scenario %s. '
117+
'Stochastic optimization requires at least one first-stage decision.',
118+
first_period,
119+
scenario_name,
120+
)
121+
raise ValueError(f'No first-stage variables found for scenario {scenario_name}')
122+
94123
# First stage cost: PeriodCost[first_period]
95124
# We can use the period_cost_rule directly
96125
first_stage_cost_expr = period_cost_rule(instance, first_period)
Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,83 @@
1+
import logging
2+
import tomllib
13
from dataclasses import dataclass, field
24
from pathlib import Path
35
from typing import Any
4-
import tomllib
6+
7+
logger = logging.getLogger(__name__)
8+
59

610
@dataclass
711
class Perturbation:
812
scenario: str
913
table: str
10-
column: str
1114
filter: dict[str, Any]
1215
action: str # 'multiply', 'add', 'set'
1316
value: float
1417

18+
def __post_init__(self) -> None:
19+
allowed_actions = {'multiply', 'add', 'set'}
20+
if self.action not in allowed_actions:
21+
raise ValueError(
22+
f"Invalid perturbation action '{self.action}'; must be one of {allowed_actions}"
23+
)
24+
25+
1526
@dataclass
1627
class StochasticConfig:
17-
scenarios: dict[str, dict[str, float]] # scenario_name -> {probability: float}
28+
scenarios: dict[str, float] # scenario_name -> probability
1829
perturbations: list[Perturbation] = field(default_factory=list)
30+
solver_options: dict[str, Any] = field(default_factory=dict)
1931

2032
@classmethod
2133
def from_toml(cls, path: Path) -> 'StochasticConfig':
2234
with open(path, 'rb') as f:
2335
data = tomllib.load(f)
2436

25-
scenarios = data.get('scenarios', {})
37+
scenarios_raw = data.get('scenarios', {})
38+
scenarios = {}
39+
for name, val in scenarios_raw.items():
40+
if isinstance(val, dict):
41+
scenarios[name] = float(val.get('probability', 1.0))
42+
else:
43+
scenarios[name] = float(val)
44+
45+
# Validate probability distribution
46+
if scenarios:
47+
total_prob = sum(scenarios.values())
48+
if not (0.99 <= total_prob <= 1.01):
49+
logger.warning(
50+
'Stochastic scenario probabilities sum to %s; usually they should sum to ~1.0',
51+
total_prob,
52+
)
53+
2654
perturbations_data = data.get('perturbations', [])
2755
perturbations = []
28-
for p in perturbations_data:
29-
perturbations.append(Perturbation(
30-
scenario=p['scenario'],
31-
table=p['table'],
32-
column=p['column'],
33-
filter=p['filter'],
34-
action=p.get('action', 'set'),
35-
value=p['value']
36-
))
37-
38-
return cls(scenarios=scenarios, perturbations=perturbations)
56+
for i, p in enumerate(perturbations_data):
57+
try:
58+
scenario_name = p['scenario']
59+
if scenario_name not in scenarios:
60+
raise ValueError(
61+
f'Perturbation at index {i} references nonexistent scenario: '
62+
f"'{scenario_name}'. Available scenarios: {list(scenarios.keys())}"
63+
)
64+
65+
perturbations.append(
66+
Perturbation(
67+
scenario=scenario_name,
68+
table=p['table'],
69+
filter=p['filter'],
70+
action=p.get('action', 'set'),
71+
value=p['value'],
72+
)
73+
)
74+
except KeyError as e:
75+
raise ValueError(f'Perturbation at index {i} is missing required field: {e}') from e
76+
77+
solver_options = data.get('solver_options', {})
78+
79+
return cls(
80+
scenarios=scenarios,
81+
perturbations=perturbations,
82+
solver_options=solver_options,
83+
)

0 commit comments

Comments
 (0)