This repository was archived by the owner on Nov 17, 2025. It is now read-only.

ifelse does not work with arbitrary data structures #1017

@rlouf

Description


What I expect

I expect ifelse to work when I pass it two arbitrarily nested Python data structures that have identical structure and whose leaves are TensorVariables of identical shape and dtype.

This would greatly facilitate the use of aesara for "complex" projects such as aehmc.
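To make the expected behavior concrete, here is a minimal sketch of one way it could work, assuming a pytree-style approach (similar in spirit to JAX's jax.tree_util): flatten both branches into flat lists of leaves, check that the two structures match, call ifelse once on the flat lists, and rebuild the result in the shape of the first branch. The names tree_flatten, tree_unflatten and nested_ifelse are hypothetical and not part of Aesara; the selection function is passed in as ifelse_fn so the sketch does not depend on any particular Aesara version.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical helpers, NOT part of Aesara. A "treedef" is a triple
# (kind, keys, children) describing how to rebuild a nested structure.
LEAF = "leaf"


def tree_flatten(tree: Any) -> Tuple[List[Any], tuple]:
    """Split nested tuples/lists/dicts/NamedTuples into (leaves, treedef)."""
    if isinstance(tree, dict):
        keys = tuple(sorted(tree))  # sorted so key insertion order does not matter
        pairs = [tree_flatten(tree[k]) for k in keys]
        kind = "dict"
    elif isinstance(tree, (tuple, list)):
        keys = None
        pairs = [tree_flatten(x) for x in tree]
        kind = type(tree)  # tuple, list, or a NamedTuple class
    else:
        # Anything else (e.g. a TensorVariable) is treated as a leaf.
        return [tree], (LEAF, None, ())
    leaves = [leaf for sub_leaves, _ in pairs for leaf in sub_leaves]
    return leaves, (kind, keys, tuple(sub_def for _, sub_def in pairs))


def tree_unflatten(treedef: tuple, leaves: Sequence[Any]) -> Any:
    """Rebuild a structure shaped like `treedef` from a flat sequence of leaves."""
    it = iter(leaves)

    def build(d: tuple) -> Any:
        kind, keys, children = d
        if kind == LEAF:
            return next(it)
        values = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, values))
        if hasattr(kind, "_make"):  # NamedTuple: restore the named type
            return kind._make(values)
        return kind(values)  # plain tuple or list

    return build(treedef)


def nested_ifelse(cond: Any, left: Any, right: Any,
                  ifelse_fn: Callable[[Any, List[Any], List[Any]], Sequence[Any]]) -> Any:
    """Select between two identically structured branches with one ifelse call."""
    left_leaves, left_def = tree_flatten(left)
    right_leaves, right_def = tree_flatten(right)
    if left_def != right_def:
        raise TypeError("both branches must have the same nested structure")
    # ifelse_fn receives two flat lists of equal length and is expected to
    # return the selected outputs in the same order.
    outputs = ifelse_fn(cond, left_leaves, right_leaves)
    return tree_unflatten(left_def, list(outputs))
```

With Aesara, ifelse_fn would presumably be aesara.ifelse.ifelse itself, since it already accepts a list of TensorVariables for each branch, e.g. nested_ifelse(cond, left, right, ifelse) for the nested (state, weight) tuples shown further down. Dict keys are sorted during flattening, so they must be mutually comparable (strings are fine).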

What I observe

It works with a tuple of TensorVariables

An IfElse Op can be created by passing a tuple or a list of TensorVariables for each branch of the condition:

import aesara.tensor as at
from aesara.ifelse import ifelse

cond = at.as_tensor(0, dtype=bool)

q_left = at.vector()
energy_left = at.scalar()
energy_grad_left = at.vector()
state_left = (q_left, energy_left, energy_grad_left)

q_right = at.vector()
energy_right = at.scalar()
energy_grad_right = at.vector()
state_right = (q_right, energy_right, energy_grad_right)

new_state = ifelse(cond, state_left, state_right)

It does not work with nested tuples

However, it does not work with nested structures. Here is a simplified version of something I have encountered in aehmc:

weight_left = at.scalar()
left = (state_left, weight_left)

weight_right = at.scalar()
right = (state_right, weight_right)

new = ifelse(cond, left, right)

This raises a TypeError.

It does not work with dictionaries

Likewise, it does not work with dictionaries as input; the following also raises a TypeError:

import aesara.tensor as at
from aesara.ifelse import ifelse

state_left = {"q": at.vector(), "energy": at.scalar(), "energy_grad": at.vector()}
state_right = {"q": at.vector(), "energy": at.scalar(), "energy_grad": at.vector()}

cond = at.as_tensor(0, dtype=bool)
state = ifelse(cond, state_left, state_right)
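Until dictionaries are supported, a flat dict can be routed through the list path that already works. A minimal sketch, assuming a hypothetical helper named dict_ifelse and a selection function ifelse_fn (aesara.ifelse.ifelse in practice): fix one key order, pass the values of both dicts as lists, and zip the outputs back with the keys. Nested dicts would need a recursive flatten instead.

```python
from typing import Any, Callable, Dict, List, Sequence

# Hypothetical workaround, NOT an Aesara API: handles flat dicts only.
def dict_ifelse(cond: Any, left: Dict[str, Any], right: Dict[str, Any],
                ifelse_fn: Callable[[Any, List[Any], List[Any]], Sequence[Any]]) -> Dict[str, Any]:
    if left.keys() != right.keys():
        raise TypeError("both dicts must have the same keys")
    keys = list(left)  # fix one key order and use it for both branches
    outputs = ifelse_fn(cond, [left[k] for k in keys], [right[k] for k in keys])
    return dict(zip(keys, outputs))
```

For the example above this would read state = dict_ifelse(cond, state_left, state_right, ifelse).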

It does not work with namedtuples

ifelse returns a plain tuple when NamedTuples are passed as input, which causes errors downstream:

from typing import NamedTuple

import aesara.tensor as at
from aesara.tensor.var import TensorVariable
from aesara.ifelse import ifelse


class State(NamedTuple):
    q: TensorVariable
    energy: TensorVariable
    energy_grad: TensorVariable

state_left = State(at.vector(), at.scalar(), at.vector())
state_right = State(at.vector(), at.scalar(), at.vector())

cond = at.as_tensor(0, dtype=bool)
state = ifelse(cond, state_left, state_right)
state.q  # AttributeError: 'tuple' object has no attribute 'q'
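Since ifelse already accepts a NamedTuple (a tuple subclass) and only loses the type on the way out, a minimal workaround is to rebuild the result with the NamedTuple's _make classmethod. This sketch assumes a hypothetical helper named namedtuple_ifelse and takes the selection function as ifelse_fn (aesara.ifelse.ifelse in practice); it only covers flat NamedTuples.

```python
from typing import Any, Callable, Sequence

# Hypothetical workaround, NOT an Aesara API: handles flat NamedTuples only.
def namedtuple_ifelse(cond: Any, left: tuple, right: tuple,
                      ifelse_fn: Callable[[Any, Any, Any], Sequence[Any]]) -> Any:
    if type(left) is not type(right):
        raise TypeError("both branches must be the same NamedTuple type")
    # ifelse_fn accepts the NamedTuples as plain tuples but returns a plain
    # tuple/list; _make rebuilds the original NamedTuple type from it.
    return type(left)._make(ifelse_fn(cond, left, right))
```

With the State class above, state = namedtuple_ifelse(cond, state_left, state_right, ifelse) would keep state.q available.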

Labels: question (Further information is requested)
