
Model base classes¤

feedbax.AbstractModel (Module, Generic[StateT]) ¤

Base class for all Feedbax models.

bounds: PyTree[StateBounds] property ¤

Suggested bounds on the state variables.

memory_spec: PyTree[bool] property ¤

Specifies which states should typically be remembered.

Info

This is not used by the model itself, but may be used by an AbstractIterator model that wraps it. When iterating very large models for many time steps, storing all states may be limiting because of available memory; not storing certain parts of the state across all time steps may be helpful.
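The idea can be illustrated with a plain-Python sketch. The `filter_state` helper and the state layout below are invented for illustration and are not the Feedbax API; they just show how a PyTree of booleans can select which substates are retained:

```python
def filter_state(state, spec):
    """Keep only substates whose entry in `spec` is True; drop the rest.

    Illustrative sketch of how a PyTree[bool] memory spec could be used
    to limit which states are stored across many time steps.
    """
    if isinstance(spec, dict):
        return {k: filter_state(state[k], spec[k]) for k in spec}
    return state if spec else None

state = {"effector": {"pos": [0.0, 1.0]}, "hidden": [0.5] * 50}
memory_spec = {"effector": {"pos": True}, "hidden": False}

# Only the effector position survives; the large hidden state is dropped.
stored = filter_state(state, memory_spec)
```

An iterator that stores states at every time step could apply such a filter before appending to its history, trading recoverability of the dropped substates for memory.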

__call__ (input: PyTree[Array], state: StateT, key: PRNGKeyArray) -> StateT abstractmethod ¤

Return an updated state, given inputs and a prior state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `PyTree[Array]` | The inputs to the model. | required |
| `state` | `StateT` | The prior state associated with the model. | required |
| `key` | `PRNGKeyArray` | A random key used for model operations. | required |
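The contract above can be sketched in plain Python. The `PointMass` model, its state class, and the Euler-step dynamics here are all invented for illustration (a real Feedbax model would be an Equinox `Module`); the point is the shape of `__call__` and `init`:

```python
class PointMassState:
    """Toy state: position and velocity of a 1D point mass (illustrative only)."""
    def __init__(self, pos: float, vel: float):
        self.pos = pos
        self.vel = vel

class PointMass:
    """Sketch of the AbstractModel contract: __call__ maps (input, state, key)
    to a new state of the same type, and init returns a default state."""
    dt = 0.1

    def __call__(self, input: float, state: PointMassState, key=None) -> PointMassState:
        # Euler step: the input acts as a force on a unit mass.
        vel = state.vel + input * self.dt
        pos = state.pos + vel * self.dt
        return PointMassState(pos, vel)

    def init(self, *, key=None) -> PointMassState:
        return PointMassState(0.0, 0.0)
```

An iterator wrapper would then repeatedly feed each returned state back in as the prior state for the next time step.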
step () -> Module ¤

The part of the model PyTree specifying a single time step of the model.

For non-iterated models, this should trivially return the model itself.

state_consistency_update (state: StateT) -> StateT ¤

Make sure the model state is self-consistent.

Info

The default behaviour is to just return the same state that was passed.

In some models, multiple representations of the same information are kept, but only one is initialized by the task. A typical example is a task that initializes a model's effector state, e.g. the location of the endpoint of the arm at the start of a reach. In the case of a two-link arm, the location of the arm's endpoint directly constrains the joint angles. For the state to be consistent, the values of the joint angles must match the effector position. Importantly, the joint angles are the representation used by the ODE describing the arm. Therefore we call state_consistency_update after initializing the state and before the first forward pass of the model, to make sure the joint angles are consistent with the effector position.

In this way we avoid having to specify redundant information in AbstractTask, and each model can handle the logic of what makes a state consistent, with respect to its own operations.
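The joint-angle update described above amounts to inverse kinematics. A minimal standalone sketch for a planar two-link arm follows; the function name, unit link lengths, and elbow-down convention are assumptions for illustration, not Feedbax's implementation:

```python
import math

def two_link_inverse_kinematics(x, y, l1=1.0, l2=1.0):
    """Return joint angles (shoulder, elbow) placing the endpoint at (x, y).

    Standard planar two-link inverse kinematics, elbow-down solution.
    Illustrative of the kind of computation a state_consistency_update
    might perform to make joint angles match an initialized effector position.
    """
    r2 = x * x + y * y
    # Law of cosines for the elbow angle; clamp against rounding error.
    cos_elbow = (r2 - l1 * l1 - l2 * l2) / (2 * l1 * l2)
    elbow = math.acos(max(-1.0, min(1.0, cos_elbow)))
    shoulder = math.atan2(y, x) - math.atan2(
        l2 * math.sin(elbow), l1 + l2 * math.cos(elbow)
    )
    return shoulder, elbow
```

A consistency update for such an arm could overwrite the stored joint angles with the result of this computation on the task-initialized endpoint position.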

init (*, key: PRNGKeyArray) -> StateT abstractmethod ¤

Return a default state for the model.

feedbax.ModelInput (Module) ¤

PyTree that contains all inputs to a model.

feedbax.wrap_stateless_callable (callable: Callable) ¤

Makes a 'stateless' callable compatible with state-passing.

Info

AbstractStagedModel defines its operations as transformations to parts of a state PyTree. Each stage of a model consists of passing a particular substate of the model to a callable that operates on it, returning an updated substate. However, in some cases the new substate does not depend on the previous substate, but is generated entirely from some other inputs.

For example, a linear neural network layer outputs an array of a certain shape, but requires only some input array, not its prior output (state) array, to do so. We can use a module like eqx.nn.Linear to update part of a model's state, as the callable of one of its model stages; however, the signature of Linear accepts only input, not state. Wrapping it in this function makes it accept state as well, though the state is simply discarded.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `callable` | `Callable` | The callable to wrap. | required |
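The idea behind the wrapper can be sketched as follows. This is an illustrative reimplementation, not the Feedbax source; the name `wrap_stateless_sketch` is invented:

```python
from functools import wraps

def wrap_stateless_sketch(callable_):
    """Accept and discard a `state` argument so a stateless callable fits a
    state-passing stage signature. Sketch only, not the Feedbax source."""
    @wraps(callable_)
    def wrapped(input, state, *args, **kwargs):
        # The new substate depends only on `input`; the prior substate is ignored.
        return callable_(input, *args, **kwargs)
    return wrapped

double = wrap_stateless_sketch(lambda x: 2 * x)
result = double(3, "prior state is discarded")  # -> 6
```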

feedbax.wrap_stateless_keyless_callable (callable: Callable) ¤

Like wrap_stateless_callable, for a callable that also takes no key.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `callable` | `Callable` | The callable to wrap. | required |
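A sketch of the keyless variant, again an illustrative reimplementation rather than the Feedbax source, with the name invented:

```python
import math

def wrap_stateless_keyless_sketch(callable_):
    """Discard both `state` and `key`, for callables that take neither.
    Sketch only, not the Feedbax source."""
    def wrapped(input, state, *, key=None):
        # Both the prior substate and the random key are ignored.
        return callable_(input)
    return wrapped

magnitude = wrap_stateless_keyless_sketch(math.fabs)
result = magnitude(-4.0, "state ignored", key=None)  # -> 4.0
```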