
Iterating a model

feedbax.iterate.Iterator (AbstractIterator[StateT])

Applies a model repeatedly, carrying state. Returns history for all states.


If memory is not a concern, this class is preferable to ForgetfulIterator: it avoids the overhead of partitioning the state, and is therefore faster.

For very large state PyTrees, however, ForgetfulIterator may be preferable, since it stores only the selected states and so saves memory.
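The trade-off comes down to what the loop emits at each step. The following is a minimal sketch in plain JAX, not feedbax's implementation: it assumes the iteration can be expressed with `jax.lax.scan`, carrying the state forward and stacking each new state along a leading time axis. The `step` function, its dict-shaped state, and `iterate_full` are hypothetical stand-ins for an `AbstractModel`, its `StateT`, and `Iterator.__call__`.

```python
import jax
import jax.numpy as jnp


def step(input, state, key):
    """Hypothetical stand-in for a model step: a noisy leaky integrator."""
    noise = 0.01 * jax.random.normal(key, state["x"].shape)
    x = state["x"] + 0.1 * (input - state["x"]) + noise
    return {"x": x, "count": state["count"] + 1}


def iterate_full(step, n_steps, input, state, key):
    """Apply `step` n_steps times and return every resulting state."""
    keys = jax.random.split(key, n_steps)  # one independent key per step

    def body(carry, step_key):
        new_state = step(input, carry, step_key)
        # First output is the carry for the next step; second is stacked into the history.
        return new_state, new_state

    _, history = jax.lax.scan(body, state, keys)
    return history  # same PyTree structure as `state`, with a leading time axis


state0 = {"x": jnp.zeros(3), "count": jnp.array(0, dtype=jnp.int32)}
history = iterate_full(step, 5, jnp.ones(3), state0, jax.random.PRNGKey(0))
print(history["x"].shape)  # (5, 3)
print(history["count"])    # [1 2 3 4 5]
```

Because every state is stacked, the history's memory footprint grows with `n_steps` times the size of the whole state PyTree. This sketch records the state after each step; it makes no claim about whether feedbax's history also includes the initial state.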


Attributes:

Name      Type                    Description
step      AbstractModel[StateT]   The model to be iterated.
n_steps   int                     The number of steps to iterate for.

__call__(input: PyTree, state: StateT, key: PRNGKeyArray) -> StateT

Parameters:

Name    Type           Description                                            Default
input   PyTree         The input to the model.                                required
state   StateT         The initial state of the model to be iterated.         required
key     PRNGKeyArray   Determines the pseudo-randomness in model execution.   required

init(*, key: PRNGKeyArray) -> StateT

feedbax.iterate.ForgetfulIterator (AbstractIterator[StateT])

Applies a model repeatedly, carrying state. Returns history only for the subset of states selected by memory_spec.


Attributes:

Name          Type                    Description
n_steps       int                     The number of steps to iterate for.
step          AbstractModel[StateT]   The model to be iterated.
memory_spec   PyTree[bool]            A PyTree of bools that is a prefix of StateT, indicating which states to store.
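One way a `memory_spec` prefix could be applied is sketched below in plain JAX. It is an illustration under stated assumptions, not feedbax's implementation, which may partition the state differently (for example with Equinox's filtering utilities). The full state is still carried from step to step, but only the subtrees flagged `True` are emitted into the history. The state layout, `step`, `select`, and `iterate_forgetful` are hypothetical.

```python
import jax
import jax.numpy as jnp


def step(input, state, key):
    """Hypothetical model step with a small and a large state component."""
    x = state["x"] + 0.1 * (input - state["x"])
    big = state["big"] + jax.random.normal(key, state["big"].shape)
    return {"x": x, "big": big}


def select(memory_spec, state):
    """Keep subtrees flagged True; replace the others with None (an empty subtree)."""
    # tree_map allows `state` to have `memory_spec` as a prefix, so each bool
    # is paired with the whole subtree of `state` at that position.
    return jax.tree_util.tree_map(
        lambda keep, subtree: subtree if keep else None, memory_spec, state
    )


def iterate_forgetful(step, n_steps, memory_spec, input, state, key):
    """Carry the full state, but stack only the remembered part into the history."""
    keys = jax.random.split(key, n_steps)

    def body(carry, step_key):
        new_state = step(input, carry, step_key)
        return new_state, select(memory_spec, new_state)

    final_state, history = jax.lax.scan(body, state, keys)
    return final_state, history


state0 = {"x": jnp.zeros(2), "big": jnp.zeros((100, 100))}
memory_spec = {"x": True, "big": False}
final_state, history = iterate_forgetful(
    step, 4, memory_spec, jnp.ones(2), state0, jax.random.PRNGKey(1)
)
print(history["x"].shape)        # (4, 2)
print(history["big"])            # None: the large component was not stored
print(final_state["big"].shape)  # (100, 100)
```

The selection performed by `select` at every step is the kind of partitioning overhead that makes Iterator the faster choice when memory is not a constraint.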

__call__(input: PyTree, state: StateT, key: PRNGKeyArray) -> StateT

Parameters:

Name    Type           Description                                            Default
input   PyTree         The input to the model.                                required
state   StateT         The initial state of the model to be iterated.         required
key     PRNGKeyArray   Determines the pseudo-randomness in model execution.   required

init(*, key: PRNGKeyArray) -> StateT

Abstract base classes

feedbax.iterate.AbstractIterator (AbstractModel[StateT])

Base class for models which iterate other models.

step: Module (property)

The model to be iterated.

__call__(input: PyTree[Array], state: StateT, key: PRNGKeyArray) -> StateT

init(*, key: PRNGKeyArray) -> StateT

Returns an initial state for the iterated model.
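The interface described in this section, an abstract `step` property together with `__call__` and `init`, can be mirrored in plain Python with the `abc` module. This is a hedged sketch only: feedbax's actual classes are Equinox modules, and `AbstractIteratorSketch`, `ToyModel`, `ToyIterator`, and the assumption that `init` simply forwards to the step model's own `init` are all hypothetical.

```python
import abc

import jax
import jax.numpy as jnp


class AbstractIteratorSketch(abc.ABC):
    """Hypothetical mirror of the documented interface (not feedbax's class)."""

    @property
    @abc.abstractmethod
    def step(self):
        """The model to be iterated."""

    @abc.abstractmethod
    def __call__(self, input, state, key):
        """Iterate `step`, returning the state history."""

    @abc.abstractmethod
    def init(self, *, key):
        """Return an initial state for the iterated model."""


class ToyModel:
    """Hypothetical single-step model with its own init."""

    def init(self, *, key):
        return {"x": jax.random.normal(key, (2,))}

    def __call__(self, input, state, key):
        # Deterministic decay toward `input`; this toy model ignores `key`.
        return {"x": 0.9 * state["x"] + input}


class ToyIterator(AbstractIteratorSketch):
    def __init__(self, step, n_steps):
        self._step = step
        self.n_steps = n_steps

    @property
    def step(self):
        return self._step

    def init(self, *, key):
        # Assumption: the initial state is just the step model's initial state.
        return self._step.init(key=key)

    def __call__(self, input, state, key):
        keys = jax.random.split(key, self.n_steps)

        def body(carry, step_key):
            new_state = self._step(input, carry, step_key)
            return new_state, new_state

        _, history = jax.lax.scan(body, state, keys)
        return history


init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
iterator = ToyIterator(ToyModel(), n_steps=10)
state = iterator.init(key=init_key)
history = iterator(jnp.zeros(2), state, run_key)
print(history["x"].shape)  # (10, 2)
```

Splitting one key into separate keys for `init` and `__call__` keeps the two sources of randomness independent while the whole run stays reproducible from a single seed.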