Model base classes
feedbax.AbstractModel(Module, Generic[StateT])
Base class for all Feedbax models.
bounds: PyTree[StateBounds] (property)
Suggested bounds on the state variables.
memory_spec: PyTree[bool] (property)
Specifies which states should typically be remembered.
Info
This is not used by the model itself, but may be used by an AbstractIterator model that wraps it. When iterating very large models for many time steps, storing the full state at every time step may exceed available memory; leaving some parts of the state out of the stored history can help.
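As a rough illustration of how a wrapping iterator might use memory_spec, the sketch below uses plain nested dicts as stand-ins for state PyTrees, so it runs without JAX. The helpers select_leaves and record_history and the state field names are hypothetical, not part of Feedbax; the one assumption taken from the description is that memory_spec mirrors the structure of the state, with a boolean per leaf. A JAX implementation would more likely use JAX's tree utilities than hand-written recursion.

```python
def select_leaves(state: dict, spec: dict) -> dict:
    """Keep only the leaves of `state` whose matching entry in `spec` is True.

    Both arguments are nested dicts with the same keys; each leaf of `spec`
    is a bool. (Hypothetical helper, not a Feedbax function.)
    """
    kept = {}
    for name, value in state.items():
        flag = spec[name]
        if isinstance(flag, dict):
            # Recurse into a sub-state; drop it entirely if nothing inside is kept.
            sub = select_leaves(value, flag)
            if sub:
                kept[name] = sub
        elif flag:
            kept[name] = value
    return kept


def record_history(step_fn, state: dict, spec: dict, n_steps: int) -> list:
    """Iterate `step_fn`, storing only the parts of each state flagged in `spec`."""
    history = []
    for _ in range(n_steps):
        state = step_fn(state)
        history.append(select_leaves(state, spec))
    return history


# Hypothetical state: a small "effector" sub-state worth keeping, and a
# large "hidden" list we do not want to store at every time step.
initial = {"effector": {"pos": 0.0, "vel": 1.0}, "hidden": [0.0] * 1000}
spec = {"effector": {"pos": True, "vel": True}, "hidden": False}


def step(state: dict) -> dict:
    # Move the effector by its velocity; carry the hidden state through unchanged.
    pos = state["effector"]["pos"] + state["effector"]["vel"]
    return {
        "effector": {"pos": pos, "vel": state["effector"]["vel"]},
        "hidden": state["hidden"],
    }


history = record_history(step, initial, spec, n_steps=3)
```

Every iteration still computes the full state; only the stored history is pruned, which is where the memory saving comes from.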
__call__(input: PyTree[Array], state: StateT, key: PRNGKeyArray) -> StateT (abstractmethod)
Return an updated state, given inputs and a prior state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | PyTree[Array] | The inputs to the model. | required |
state | StateT | The prior state associated with the model. | required |
key | PRNGKeyArray | A random key used for model operations. | required |
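The state-passing contract of __call__ (take a prior state, return a new one) can be sketched in plain Python so that it runs without JAX or Equinox. ToyAbstractModel, LeakyIntegrator, and IntegratorState are hypothetical names; a real Feedbax model would subclass feedbax.AbstractModel (an Equinox Module), operate on arrays rather than floats, and receive a JAX PRNG key. The sketch covers only __call__, init, and the default state_consistency_update.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Generic, TypeVar

StateT = TypeVar("StateT")


class ToyAbstractModel(ABC, Generic[StateT]):
    """Stand-in for the AbstractModel interface (not the Feedbax class itself)."""

    @abstractmethod
    def __call__(self, input: float, state: StateT, key: object) -> StateT:
        """Return an updated state, given inputs and a prior state."""

    @abstractmethod
    def init(self, *, key: object) -> StateT:
        """Return a default state for the model."""

    def state_consistency_update(self, state: StateT) -> StateT:
        # Default behaviour: return the state unchanged.
        return state


@dataclass(frozen=True)
class IntegratorState:
    x: float


@dataclass(frozen=True)
class LeakyIntegrator(ToyAbstractModel[IntegratorState]):
    tau: float = 10.0  # decay time constant (hypothetical value)
    dt: float = 1.0  # time step (hypothetical value)

    def __call__(self, input: float, state: IntegratorState, key: object) -> IntegratorState:
        # Pure state passing: build a new state instead of mutating the old one.
        # This toy model is deterministic, so `key` is unused.
        dx = self.dt * (-state.x / self.tau + input)
        return replace(state, x=state.x + dx)

    def init(self, *, key: object) -> IntegratorState:
        return IntegratorState(x=0.0)


model = LeakyIntegrator()
s0 = model.init(key=None)
s1 = model(1.0, s0, key=None)
```

Because the prior state is never modified, s0 still holds the initial value after the call, which is the property that lets states be stored, compared, or replayed freely.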
step() -> Module
The part of the model PyTree specifying a single time step of the model.
For non-iterated models, this should trivially return the model itself.
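To make the distinction concrete, the following plain-Python sketch defines a hypothetical single-step Accumulator whose step is the model itself, and a hypothetical ToyIterator whose step is the model it wraps. Neither is a Feedbax class; Feedbax's own iterators (AbstractIterator) are Equinox modules operating on arrays and may differ in detail. The sketch also shows the order of operations assumed here: init, then state_consistency_update, then repeated single-step calls.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Accumulator:
    """Hypothetical single-step model whose state is a running total."""

    @property
    def step(self) -> "Accumulator":
        # Non-iterated model: its single time step is the model itself.
        return self

    def init(self, *, key: object) -> float:
        return 0.0

    def state_consistency_update(self, state: float) -> float:
        return state

    def __call__(self, input: float, state: float, key: object) -> float:
        return state + input


@dataclass(frozen=True)
class ToyIterator:
    """Hypothetical iterator applying a single-step model over time."""

    model: Accumulator
    n_steps: int

    @property
    def step(self) -> Accumulator:
        # Iterated model: the single time step is the wrapped model.
        return self.model

    def init(self, *, key: object) -> float:
        return self.step.init(key=key)

    def __call__(self, input: list, state: float, key: object) -> list:
        # Returns the history of states, one per time step.
        states = []
        for t in range(self.n_steps):
            state = self.step(input[t], state, key)
            states.append(state)
        return states


single = Accumulator()
iterated = ToyIterator(model=single, n_steps=3)
state0 = iterated.step.state_consistency_update(iterated.init(key=None))
states = iterated([1.0, 2.0, 3.0], state0, key=None)
```

Code that needs the per-step model can then ask for model.step without knowing whether the model is iterated.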
state_consistency_update(state: StateT) -> StateT
Make sure the model state is self-consistent.
Info
The default behaviour is to return the state that was passed, unchanged.
In some models, multiple representations of the same information are kept, but only one is initialized by the task. A typical example is a task that initializes a model's effector state, e.g. the location of the arm's endpoint at the start of a reach. In the case of a two-link arm, the location of the arm's endpoint directly constrains the joint angles. For the state to be consistent, the values of the joint angles should match the effector position. Importantly, the joint angles are the representation used by the ODE describing the arm. Therefore we call state_consistency_update after initializing the state and before the first forward pass of the model, to make sure the joint angles are consistent with the effector position.
In this way we avoid having to specify redundant information in AbstractTask, and each model can handle the logic of what makes a state consistent, with respect to its own operations.
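The two-link arm case can be sketched with standard planar two-link kinematics: the task sets only the endpoint position, and the consistency update derives the joint angles from it by inverse kinematics. TwoLinkArm, TwoLinkState, the link lengths, and the choice of the positive-elbow solution are assumptions for illustration; Feedbax's own arm model may store its state and resolve the inverse kinematics differently.

```python
import math
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class TwoLinkState:
    """Hypothetical arm state holding two representations of the same posture."""

    angles: Tuple[float, float]  # (shoulder, elbow) joint angles in radians
    effector: Tuple[float, float]  # (x, y) endpoint position


@dataclass(frozen=True)
class TwoLinkArm:
    l1: float = 0.3  # upper-arm length (hypothetical value)
    l2: float = 0.33  # forearm length (hypothetical value)

    def forward_kinematics(self, angles: Tuple[float, float]) -> Tuple[float, float]:
        """Endpoint position implied by the joint angles."""
        a1, a2 = angles
        x = self.l1 * math.cos(a1) + self.l2 * math.cos(a1 + a2)
        y = self.l1 * math.sin(a1) + self.l2 * math.sin(a1 + a2)
        return (x, y)

    def inverse_kinematics(self, effector: Tuple[float, float]) -> Tuple[float, float]:
        """Joint angles placing the endpoint at `effector` (positive-elbow branch)."""
        x, y = effector
        # Law of cosines: x^2 + y^2 = l1^2 + l2^2 + 2*l1*l2*cos(a2).
        c2 = (x**2 + y**2 - self.l1**2 - self.l2**2) / (2 * self.l1 * self.l2)
        a2 = math.acos(max(-1.0, min(1.0, c2)))  # clamp for numerical safety
        a1 = math.atan2(y, x) - math.atan2(
            self.l2 * math.sin(a2), self.l1 + self.l2 * math.cos(a2)
        )
        return (a1, a2)

    def state_consistency_update(self, state: TwoLinkState) -> TwoLinkState:
        # The task set only `effector`; derive the joint angles the ODE uses.
        return replace(state, angles=self.inverse_kinematics(state.effector))


arm = TwoLinkArm()
# The task initializes only the endpoint; the joint angles are placeholders.
initial = TwoLinkState(angles=(0.0, 0.0), effector=(0.3, 0.4))
consistent = arm.state_consistency_update(initial)
```

Running the forward kinematics on consistent.angles recovers the endpoint the task specified, which is exactly the consistency the update is meant to guarantee.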
init(*, key: PRNGKeyArray) -> StateT (abstractmethod)
Return a default state for the model.
feedbax.ModelInput(Module)
PyTree that contains all inputs to a model.
feedbax.wrap_stateless_callable(callable: Callable)
Makes a 'stateless' callable compatible with state-passing.
Info
AbstractStagedModel defines its operations as transformations to parts of a state PyTree. Each stage of a model consists of passing a particular substate of the model to a callable that operates on it, returning an updated substate. However, in some cases the new substate does not depend on the previous substate, but is generated entirely from some other inputs.
For example, a linear neural network layer outputs an array of a certain shape, but only requires some input array (and not its prior output, or state, array) to do so. We can use a module like eqx.nn.Linear to update part of a model's state, as the callable of one of its model stages; however, Linear accepts an input (plus an optional key) but not a state. By wrapping it with this function, we can make it accept state as well, though the state is simply discarded.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
callable | Callable | The callable to wrap. | required |
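A minimal sketch of the wrapping idea, assuming stage callables are invoked as (input, state, key=...) and that the wrapped callable accepts key as a keyword argument, as eqx.nn.Linear does. wrap_stateless_callable_sketch and double_layer are hypothetical names, and this is not Feedbax's implementation; plain lists stand in for arrays so the example runs without JAX.

```python
from typing import Any, Callable


def wrap_stateless_callable_sketch(fn: Callable) -> Callable:
    """Hypothetical re-implementation: accept (input, state, key) and drop `state`."""

    def wrapped(input: Any, state: Any, *, key: Any = None) -> Any:
        # `state` is accepted only to fit the stage signature, then discarded.
        return fn(input, key=key)

    return wrapped


def double_layer(x: list, *, key: Any = None) -> list:
    """Stand-in for a layer like eqx.nn.Linear: output depends only on the input."""
    return [2.0 * v for v in x]


stage_fn = wrap_stateless_callable_sketch(double_layer)
# The prior "state" is ignored; only the input determines the new substate.
out = stage_fn([1.0, 2.0], state=[99.0, 99.0], key=None)
```

The wrapper changes only the calling convention; whatever the wrapped callable returns becomes the new substate.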
feedbax.wrap_stateless_keyless_callable(callable: Callable)
Like wrap_stateless_callable, for a callable that also takes no key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
callable | Callable | The callable to wrap. | required |
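The keyless variant can be sketched the same way, for callables such as jax.numpy.tanh that take neither a state nor a key. The name wrap_stateless_keyless_callable_sketch is hypothetical and the (input, state, key=...) stage signature is an assumption; Python's built-in abs stands in for a keyless array function so the example runs without JAX.

```python
from typing import Any, Callable


def wrap_stateless_keyless_callable_sketch(fn: Callable) -> Callable:
    """Hypothetical re-implementation: accept (input, state, key); pass only `input`."""

    def wrapped(input: Any, state: Any, *, key: Any = None) -> Any:
        # Both `state` and `key` are accepted for signature compatibility and ignored.
        return fn(input)

    return wrapped


# `abs` takes a single argument and no key, so it needs the keyless wrapper.
stage_fn = wrap_stateless_keyless_callable_sketch(abs)
out = stage_fn(-3.5, state=0.0, key="unused")
```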