Iterating a model
feedbax.iterate.Iterator(AbstractIterator[StateT])
Applies a model repeatedly, carrying state. Returns history for all states.
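Carrying state while recording every intermediate state follows the same contract as jax.lax.scan, which calls f(carry, x) -> (carry, y) and returns the final carry together with the stacked y values. The minimal sketch below illustrates that idea only; it is not feedbax's implementation. It uses a pure-Python stand-in for scan and plain dicts in place of JAX PyTrees so it runs without JAX, and the names scan, toy_step and iterate_full_history are hypothetical. The key argument is omitted.

```python
from typing import Any, Callable, Dict, List, Tuple

State = Dict[str, float]  # hypothetical stand-in for a state PyTree


def scan(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any, xs: List[Any]) -> Tuple[Any, List[Any]]:
    """Pure-Python stand-in for jax.lax.scan: carry state, collect outputs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def toy_step(input_: float, state: State) -> State:
    """Hypothetical one-step model: a leaky integrator driven by the input."""
    return {"x": 0.5 * state["x"] + input_}


def iterate_full_history(step, input_, state, n_steps):
    """Apply `step` n_steps times and return every new state (full history)."""
    def body(carry, _):
        new_state = step(input_, carry)
        return new_state, new_state  # carry the state AND record all of it

    _, history = scan(body, state, [None] * n_steps)
    return history


history = iterate_full_history(toy_step, 1.0, {"x": 0.0}, n_steps=3)
print(history)  # [{'x': 1.0}, {'x': 1.5}, {'x': 1.75}]
```

In JAX the recorded states would be stacked into arrays with a leading time axis rather than collected in a list. Whether feedbax's history also includes the initial state is not stated on this page; the sketch records only the n_steps new states.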
Note
If memory is not a concern, this class is preferred to ForgetfulIterator: it avoids the overhead of partitioning the state, and is therefore faster. For very large state PyTrees, however, ForgetfulIterator may be preferable because it saves memory.
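The memory trade-off is simple arithmetic: storing every state costs roughly n_steps times the size of one state, whereas recording only some parts of the state costs n_steps times the size of those parts. The sketch below computes both for a hypothetical float32 state with a batch of 250 trials; the leaf names, shapes and keep-mask are made up for illustration and are not taken from feedbax.

```python
from math import prod

# Hypothetical state: leaf name -> array shape (batch of 250 trials, float32)
STATE_SHAPES = {"network_hidden": (250, 256), "effector_pos": (250, 2)}
BYTES_PER_ELEMENT = 4  # float32


def history_bytes(shapes, n_steps, keep=None):
    """Approximate bytes needed to store n_steps copies of the kept leaves."""
    if keep is None:
        keep = {name: True for name in shapes}  # full history: keep every leaf
    per_step = sum(prod(shape) for name, shape in shapes.items() if keep[name])
    return n_steps * per_step * BYTES_PER_ELEMENT


full = history_bytes(STATE_SHAPES, n_steps=1000)
partial = history_bytes(
    STATE_SHAPES, n_steps=1000, keep={"network_hidden": False, "effector_pos": True}
)
print(full, partial)  # 258000000 2000000
```

Here the full history needs about 258 MB, while recording only the effector position needs about 2 MB.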
Attributes:
Name | Type | Description |
---|---|---|
step | AbstractModel[StateT] | The model to be iterated. |
n_steps | int | The number of steps to iterate for. |
__call__(input: PyTree, state: StateT, key: PRNGKeyArray) -> StateT
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | PyTree | The input to the model. | required |
state | StateT | The initial state of the model to be iterated. | required |
key | PRNGKeyArray | Determines the pseudo-randomness in model execution. | required |
init(*, key: PRNGKeyArray) -> StateT
Inherited from feedbax.iterate.AbstractIterator
feedbax.iterate.ForgetfulIterator(AbstractIterator[StateT])
Applies a model repeatedly, carrying state. Returns history for a subset of states.
Attributes:
Name | Type | Description |
---|---|---|
n_steps | int | The number of steps to iterate for. |
step | AbstractModel[StateT] | The model to be iterated. |
memory_spec | PyTree[bool] | A PyTree of bools, a prefix of … |
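The memory_spec description above is cut off, so the following is only a sketch under an assumption suggested by the class description: that a True entry marks a part of the state whose history is kept, and that, being a prefix of the state, a single bool can cover a whole subtree. Equinox's eqx.partition uses a similar convention (a bool filter_spec that may be a PyTree prefix, with None in place of filtered-out leaves); whether ForgetfulIterator partitions the state this way is likewise an assumption. Nested dicts stand in for PyTrees, and select, toy_step and iterate_forgetful are hypothetical names.

```python
from typing import Any, Dict, Union

Spec = Union[bool, Dict[str, "Spec"]]


def select(state: Any, spec: Spec) -> Any:
    """Keep the parts of a nested-dict state that `spec` marks True.

    `spec` is a prefix of `state`: a bool at some node applies to that
    whole subtree. Dropped parts are replaced with None.
    """
    if isinstance(spec, bool):
        return state if spec else None
    return {k: select(state[k], spec[k]) for k in state}


def toy_step(input_: float, state: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical model: a point mass driven by the input, plus a counter."""
    vel = state["mechanics"]["vel"] + input_
    pos = state["mechanics"]["pos"] + vel
    hidden = state["net"]["hidden"] + 1.0
    return {"mechanics": {"pos": pos, "vel": vel}, "net": {"hidden": hidden}}


def iterate_forgetful(step, input_, state, n_steps, spec):
    """Carry the full state, but store only the parts selected by `spec`."""
    history = []
    for _ in range(n_steps):
        state = step(input_, state)
        history.append(select(state, spec))  # forget the unselected parts
    return state, history


init_state = {"mechanics": {"pos": 0.0, "vel": 0.0}, "net": {"hidden": 0.0}}
spec = {"mechanics": True, "net": False}  # remember mechanics, forget the network
final, hist = iterate_forgetful(toy_step, 1.0, init_state, 2, spec)
```

Note that the full state is still carried from step to step; only what is stored in the history shrinks.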
__call__(input: PyTree, state: StateT, key: PRNGKeyArray) -> StateT
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | PyTree | The input to the model. | required |
state | StateT | The initial state of the model to be iterated. | required |
key | PRNGKeyArray | Determines the pseudo-randomness in model execution. | required |
init(*, key: PRNGKeyArray) -> StateT
Inherited from feedbax.iterate.AbstractIterator
Abstract base classes
feedbax.iterate.AbstractIterator(AbstractModel[StateT])
Base class for models which iterate other models.
step: Module (property)
The model to be iterated.
__call__(input: PyTree[Array], state: StateT, key: PRNGKeyArray) -> StateT
Inherited from feedbax._model.AbstractModel
init(*, key: PRNGKeyArray) -> StateT
Returns an initial state for the iterated model.
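The interface documented above (a step property, __call__ taking input, state and key, and init returning an initial state) can be mirrored with Python's abc module. This is a hypothetical sketch, not feedbax's class hierarchy: feedbax models are JAX/Equinox modules, whereas here plain classes are used, the key is an unused placeholder, and init is assumed to delegate to the iterated model's own init. ToyModel, AbstractIteratorSketch and SketchIterator are illustrative names.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

State = Dict[str, float]


class ToyModel:
    """Hypothetical one-step model with its own init method."""

    def init(self, *, key: Any) -> State:
        return {"x": 0.0}

    def __call__(self, input_: float, state: State, key: Any) -> State:
        return {"x": state["x"] + input_}


class AbstractIteratorSketch(ABC):
    """Mirrors the documented interface: a `step` property, __call__, init."""

    @property
    @abstractmethod
    def step(self) -> ToyModel:
        """The model to be iterated."""

    @abstractmethod
    def __call__(self, input_: Any, state: State, key: Any) -> Any:
        ...

    def init(self, *, key: Any) -> State:
        # Assumption: the initial state comes from the iterated model.
        return self.step.init(key=key)


class SketchIterator(AbstractIteratorSketch):
    """Concrete subclass that keeps the full history, like Iterator."""

    def __init__(self, step: ToyModel, n_steps: int):
        self._step = step
        self.n_steps = n_steps

    @property
    def step(self) -> ToyModel:
        return self._step

    def __call__(self, input_: Any, state: State, key: Any) -> List[State]:
        history = []
        for _ in range(self.n_steps):
            # Placeholder key reused every step; a JAX version would
            # typically split the key so each step gets fresh randomness.
            state = self.step(input_, state, key)
            history.append(state)
        return history


iterator = SketchIterator(ToyModel(), n_steps=4)
states = iterator(2.0, iterator.init(key=None), key=None)
```

Putting init on the base class means every concrete iterator obtains its starting state the same way, while subclasses differ only in how much history they keep.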