This repository was archived by the owner on Nov 17, 2025. It is now read-only.
ifelse does not work with arbitrary data structures #1017
Closed
Labels: question (Further information is requested)
Description
What I expect
I expect ifelse to work when I pass it two arbitrarily nested Python data structures that have identical structure and whose leaves are TensorVariables of identical shape and dtype.
This would make it much easier to use Aesara in "complex" projects such as aehmc.
What I observe
It works with a tuple of TensorVariables
An IfElse Op can be created by passing a tuple or a list of TensorVariables for each branch of the condition:
```python
import aesara.tensor as at
from aesara.ifelse import ifelse

cond = at.as_tensor(0, dtype=bool)

q_left = at.vector()
energy_left = at.scalar()
energy_grad_left = at.vector()
state_left = (q_left, energy_left, energy_grad_left)

q_right = at.vector()
energy_right = at.scalar()
energy_grad_right = at.vector()
state_right = (q_right, energy_right, energy_grad_right)

new_state = ifelse(cond, state_left, state_right)
```
It does not work with nested tuples
However, it does not work with nested structures. Here is a simplified version of something I have encountered in aehmc:
```python
# Continues the previous example (reuses cond, state_left and state_right).
weight_left = at.scalar()
left = (state_left, weight_left)

weight_right = at.scalar()
right = (state_right, weight_right)

new = ifelse(cond, left, right)
```
This raises a TypeError.
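Until ifelse accepts nested inputs, one possible workaround is to flatten both branches into flat lists of leaves, call ifelse on those lists (which, per the tuple example, does work), and rebuild the nesting from the outputs. This is a minimal sketch, not Aesara's API: flatten_nested, unflatten_nested and nested_ifelse are hypothetical helpers, ifelse is passed in as the ifelse_fn parameter so the snippet runs without Aesara, and it assumes ifelse returns its outputs in the same order as its inputs. It only recurses into plain tuples and lists.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# A structure spec is None for a leaf, or (container_type, [child specs]).
Spec = Optional[Tuple[type, list]]


def flatten_nested(tree: Any) -> Tuple[List[Any], Spec]:
    """Flatten plain (possibly nested) tuples and lists into leaves plus a spec."""
    if type(tree) in (tuple, list):
        leaves: List[Any] = []
        child_specs = []
        for child in tree:
            child_leaves, child_spec = flatten_nested(child)
            leaves.extend(child_leaves)
            child_specs.append(child_spec)
        return leaves, (type(tree), child_specs)
    return [tree], None  # anything else (e.g. a TensorVariable) is a leaf


def unflatten_nested(spec: Spec, leaves: Sequence[Any]) -> Any:
    """Rebuild the structure described by spec, consuming leaves in order."""
    it = iter(leaves)

    def build(s: Spec) -> Any:
        if s is None:
            return next(it)
        container_type, child_specs = s
        return container_type([build(c) for c in child_specs])

    return build(spec)


def nested_ifelse(cond: Any, left: Any, right: Any, ifelse_fn: Callable) -> Any:
    """Run a flat, list-based ifelse on two branches of identical nested structure."""
    left_leaves, left_spec = flatten_nested(left)
    right_leaves, right_spec = flatten_nested(right)
    if left_spec != right_spec:
        raise TypeError("both branches must have the same nested structure")
    out = ifelse_fn(cond, left_leaves, right_leaves)
    if not isinstance(out, (list, tuple)):
        out = [out]  # assumption: a single output may come back unwrapped
    return unflatten_nested(left_spec, out)
```

With Aesara installed, the failing call above would become nested_ifelse(cond, left, right, ifelse), which should return a ((q, energy, energy_grad), weight) structure. Comparing the specs up front gives a clearer error than a mismatched flat call would.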
It does not work with dictionaries
Likewise, it does not work with dictionaries as input; the following also raises a TypeError:
```python
import aesara.tensor as at
from aesara.ifelse import ifelse

state_left = {"q": at.vector(), "energy": at.scalar(), "energy_grad": at.vector()}
state_right = {"q": at.vector(), "energy": at.scalar(), "energy_grad": at.vector()}
cond = at.as_tensor(0, dtype=bool)

state = ifelse(cond, state_left, state_right)
```
It does not work with namedtuples
The ifelse function returns a plain tuple when NamedTuples are passed as input, which causes errors downstream: the field names are lost, so accessing state.q on the result fails.
```python
from typing import NamedTuple

import aesara.tensor as at
from aesara.tensor.var import TensorVariable
from aesara.ifelse import ifelse


class State(NamedTuple):
    q: TensorVariable
    energy: TensorVariable
    energy_grad: TensorVariable


state_left = State(at.vector(), at.scalar(), at.vector())
state_right = State(at.vector(), at.scalar(), at.vector())
cond = at.as_tensor(0, dtype=bool)

state = ifelse(cond, state_left, state_right)
state.q  # fails: the result is a plain tuple without State's field names
```
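Dictionaries and NamedTuples are both records whose fields can be laid out in a fixed order, so a possible workaround is to pass their fields to ifelse as flat lists and reassemble the record from the outputs. This is a minimal sketch under stated assumptions, not part of Aesara: record_ifelse and its ifelse_fn parameter are hypothetical names, the State fields are annotated Any instead of TensorVariable so the snippet runs without Aesara, ifelse is assumed to return outputs in input order, and only one level of fields is handled.

```python
from typing import Any, Callable, List, NamedTuple


def _as_list(out: Any) -> List[Any]:
    # Assumption: a single output may come back unwrapped rather than in a list.
    return list(out) if isinstance(out, (list, tuple)) else [out]


def record_ifelse(cond: Any, left: Any, right: Any, ifelse_fn: Callable) -> Any:
    """Apply a list-based ifelse to a dict or NamedTuple and rebuild the record."""
    if type(left) is not type(right):
        raise TypeError("both branches must have the same container type")
    if isinstance(left, dict):
        if left.keys() != right.keys():
            raise TypeError("both dictionaries must have the same keys")
        keys = list(left)  # use the left branch's key order for both sides
        out = _as_list(ifelse_fn(cond, [left[k] for k in keys], [right[k] for k in keys]))
        return dict(zip(keys, out))
    if isinstance(left, tuple) and hasattr(left, "_fields"):
        out = _as_list(ifelse_fn(cond, list(left), list(right)))
        return type(left)._make(out)  # rebuild an instance of the same NamedTuple class
    raise TypeError(f"unsupported container type: {type(left).__name__}")


class State(NamedTuple):
    # The issue annotates these fields as TensorVariable; Any keeps the sketch Aesara-free.
    q: Any
    energy: Any
    energy_grad: Any
```

With Aesara, record_ifelse(cond, state_left, state_right, ifelse) should return a State, so state.q works again. Using the left branch's key order for both dictionaries means the two branches may list their keys in different orders.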