
Models with stages

In this example, we'll check out

  • the structure of Feedbax staged models;
  • how their structure allows them to be modified with arbitrary interventions;
  • how to write your own staged models.

This is an advanced example that assumes familiarity with type annotations and dataclasses/Equinox modules.

All Feedbax models are defined as Equinox modules. A module gathers together model parameters with the model computation (function) they parametrize. A module may be composed of other modules, in which case it forms a nested structure of model components and their respective parameters. The examples of SimpleFeedback modules we define below are all of this kind.

Once a model object is constructed, we can call it like a function. We're allowed to do that because we define its __call__ method.
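
For example, here's a toy module (not part of Feedbax) with a single parameter and a __call__ method:

import equinox as eqx
import jax.numpy as jnp


class Scale(eqx.Module):
    factor: float

    def __call__(self, x):
        return self.factor * x


model = Scale(factor=2.0)
model(jnp.arange(3.0))  # Array([0., 2., 4.], dtype=float32)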

The structure of a staged model

A lot can happen in __call__.

Consider the following simplified definition of feedbax.bodies.SimpleFeedback, which is a model of a single time step in which a neural network, after receiving some sensory feedback, sends a command to a mechanical model:

from collections.abc import Callable

import equinox as eqx
import jax.random as jr
from jaxtyping import PRNGKeyArray, PyTree

from feedbax.mechanics import Mechanics, MechanicsState
from feedbax.nn import SimpleStagedNetwork, NetworkState
from feedbax.channel import Channel, ChannelState


class SimpleFeedbackState(eqx.Module):
    mechanics: MechanicsState
    net: NetworkState
    feedback: ChannelState


class SimpleFeedback(eqx.Module):
    """Model of one step around a feedback loop between a neural network 
    and a mechanical model.
    """
    mechanics: Mechanics 
    net: SimpleStagedNetwork
    feedback_channel: Channel
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton

    def __call__(
        self, 
        input: PyTree,  
        state: SimpleFeedbackState, 
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:

        key1, key2 = jr.split(key)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output), 
            state.net, 
            key2
        )

        mechanics_state = self.mechanics(
            network_state.output, 
            state.mechanics
        )        

        return SimpleFeedbackState(
            mechanics=mechanics_state, 
            net=network_state,
            feedback=feedback_state,
        )

    # ...
    # We've omitted a bunch of other stuff from this definition!

First, notice that SimpleFeedback is an Equinox Module subclass. It's not obvious from this code alone, but Mechanics and Channel are also Module subclasses, with their own parameters and components.

Observe the following about __call__:

  • it takes a SimpleFeedbackState object, and returns a new one constructed from all the updated substates. This is where "the state of the model" is actually stored—not in SimpleFeedback itself.
  • it also takes an argument input. This argument is a PyTree that includes any inputs to the model that aren't part of its state.
  • it contains several steps. Each step involves calling one of the components of the model, such as self.feedback_channel or self.net. Each component is passed some part of the model state, and returns an updated version of that part. Like the parent model, each component also takes an argument of additional information as input.

    The components of the model that are called during __call__ are self.feedback_channel, self.net, and self.mechanics. In turn:

    self.feedback_channel

    • is a Channel object, which is a type of Equinox Module;
    • takes as input self.where_feedback(state), which is the part of state we want to store in the state of the feedback channel, to be retrieved on some later time step, depending on the delay;
    • takes state.feedback as its prior state. Looking at SimpleFeedbackState: state.feedback is a ChannelState object. This makes sense, as ChannelState is the type of state PyTree that Channel depends on;
    • returns an updated ChannelState, which we assign to feedback_state.

    Note that the default for self.where_feedback is lambda state: state.mechanics.plant.skeleton, which means that our sensory feedback consists of the full state of the skeleton—typically, the positions and velocities of some joints.

    self.net

    • is a SimpleStagedNetwork object, which is a type of Equinox Module;
    • takes as input (input, feedback_state.output). Here, input is the entire argument passed to __call__ itself. Since SimpleFeedback is a top-level model, its input will consist of the trial-by-trial information needed to complete the task, and the neural network is the proper recipient of this information;
    • takes state.net as its prior state. This is a NetworkState object;
    • returns an updated NetworkState, which we assign to network_state.

    This is the only stage of the model that receives the input that was passed to SimpleFeedback itself. This is because the input to the model is typically information the network needs to complete the task—say, the position of the goal it should reach to. The input to all of the other model stages is some other part of the model state.

    self.mechanics

    • is a Mechanics object, which is a type of Equinox Module;
    • takes as input network_state.output, where network_state contains the updated NetworkState returned by self.net: network_state.output is the command we want to send to the mechanical model;
    • takes state.mechanics as its prior state. This is a MechanicsState object;
    • returns an updated MechanicsState, which we assign to mechanics_state.

Trying to intervene

What if we want to interfere with the command the neural network generates, after we call self.net but before we call self.mechanics? We could write a new module with an extra component that operates on NetworkState, and call it at the right moment:

class SimpleFeedbackPerturbNetworkOutput(eqx.Module):
    net: eqx.Module  
    mechanics: Mechanics 
    feedback_channel: Channel
    # Fields without defaults have to precede those with defaults.
    intervention: eqx.Module
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton

    def __call__(
        self, 
        input: PyTree,  
        state: SimpleFeedbackState, 
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:

        key1, key2 = jr.split(key)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output), 
            state.net, 
            key2
        )

        # modifies `network_state.output` somehow
        network_state = self.intervention(network_state)

        mechanics_state = self.mechanics(
            network_state.output, 
            state.mechanics
        )        

        return SimpleFeedbackState(
            mechanics=mechanics_state, 
            net=network_state,
            feedback=feedback_state,
        )

It would be pretty inconvenient to have to do this every time we want to intervene a little. Once we have a model, it'd be nice to experiment on it quickly.

Also, if we have a different model that's similar enough to SimpleFeedback that it might make sense to use the same kind of NetworkState intervention that we just used, we wouldn't want to have to manually rewrite that model too.

Thankfully we can do something about this.

A more general way to intervene

Start by noticing that each step in the __call__ method of our original SimpleFeedback:

  • is defined as a modification of some part of the model state: each operation we perform returns some part of SimpleFeedbackState;
  • calls a model component in a consistent way: no matter if we're calling self.feedback_channel, self.net, or self.mechanics, our call always looks like self.something(input_to_something, state_associated_with_something, key).

Each component will also need a key, but we won't need to specify how that works for each component individually: we'll just have to be sure to split up the key passed to __call__, so that each component gets a different key.

That means we can define each step in __call__ with three pieces of information:

  1. Which model component to call—say, self.net;
  2. How to select the input to that model component, from the full input and state passed to SimpleFeedback;
  3. How to select the state associated with (and modified by) that model component, out of the full state passed to SimpleFeedback.

OK, let's try to do that. We'll define an object called ModelStage which holds the three pieces of information required to define each model stage. Then we'll define a model_spec that defines all the stages of our model.

from collections import OrderedDict


class ModelStage(eqx.Module):
    callable: Callable
    where_input: Callable
    where_state: Callable


model_spec = OrderedDict({
    'update_feedback': ModelStage(
        # See explanation below for why we define this as a lambda!
        callable=lambda self: self.feedback_channel,
        # Note: `self` here refers to the model instance. This spec will
        # shortly move into a property of the model class, where `self`
        # is actually in scope.
        where_input=lambda input, state: self.where_feedback(state),
        where_state=lambda state: state.feedback,
    ),
    'net_step': ModelStage(
        callable=lambda self: self.net,
        where_input=lambda input, state: (input, state.feedback.output),
        where_state=lambda state: state.net,
    ),
    'mechanics_step': ModelStage(
        callable=lambda self: self.mechanics,
        where_input=lambda input, state: state.net.output,
        where_state=lambda state: state.mechanics,
    ),
})

Note

The model stages need to be executed in order. Even though a dict maintains the order of its entries in current Python versions, jax.tree_util sorts the keys of a plain dict when flattening it, so that order is not preserved across PyTree operations. Therefore we use OrderedDict, which JAX flattens in insertion order.
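
A quick check of this behavior:

import jax
from collections import OrderedDict

d = dict(b=1, a=2)
od = OrderedDict(b=1, a=2)

jax.tree_util.tree_leaves(d)   # [2, 1]: plain dict keys are sorted on flattening
jax.tree_util.tree_leaves(od)  # [1, 2]: OrderedDict keeps insertion order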

Note

Each of the fields in ModelStage is typed as Callable, which means it can be a function, a method, or any object with __call__ defined. In this case, we define them inline as lambda functions.

For where_input and where_state, this is similar to what we've seen in earlier examples. For example, where_input takes the input and state passed to __call__, and returns the parts to be passed to the current stage's callable.

Why do we define callable as lambda self: self.something, rather than just self.something? It's to make sure that references to the component "stay fresh": Feedbax routinely performs model surgery, building a new model with some component swapped out (for example, with eqx.tree_at), and a spec that captured self.something directly would keep pointing at the component of the old model. Deferring the lookup with a lambda means the component is retrieved from the current model each time it's needed. If that doesn't all make sense to you, don't worry about it at this point. Just remember that if you write your own staged models, you will need to write your model_spec this way.
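
Here's a toy illustration of the problem, assuming the kind of model surgery Feedbax performs with eqx.tree_at; the Toy module is just for demonstration:

import equinox as eqx
import jax
import jax.numpy as jnp


class Toy(eqx.Module):
    weight: jax.Array


m1 = Toy(weight=jnp.zeros(3))

direct = m1.weight                # captured once, from m1
fresh = lambda self: self.weight  # lookup deferred until we have an instance

# Model surgery: build a new model with the component replaced.
m2 = eqx.tree_at(lambda m: m.weight, m1, jnp.ones(3))

direct      # still the old zeros from m1
fresh(m2)   # the new ones, looked up on the current model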

In order to insert interventions at arbitrary points, here's what we'll do:

  1. include model_spec as an attribute of SimpleFeedback itself;
  2. define __call__ so that it calls each of the components in model_spec, passing them their respective subsets of the input and state, and using their return value to update the model state.

    Importantly, the way we define __call__ will no longer allow our model stages to assign to, or refer to, intermediate variables like feedback_state. This is why, in the model_spec we just defined, the input to self.net includes state.feedback.output, whereas in our original definition of SimpleFeedback we had passed feedback_state.output.

    In our new __call__, we'll update the full model state immediately after each stage, rather than assigning to intermediate variables and then finally constructing a new SimpleFeedbackState. Every stage's inputs are only selected out of the full model state, not out of intermediate copies of parts of the state.

  3. give SimpleFeedback a new attribute intervenors, where we can insert additional components that intervene on the model's state, given the name of the model stage they should be applied before. For example, if this attribute is set to {'mechanics_step': [some_intervention, some_other_intervention]} then some_intervention and some_other_intervention would be called one after the other, immediately before self.mechanics is called.

class SimpleFeedback(eqx.Module):
    net: eqx.Module  
    mechanics: Mechanics 
    feedback_channel: Channel
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton
    # Default to no intervenors.
    intervenors: dict = eqx.field(default_factory=dict)

    @property
    def model_spec(self):
        return OrderedDict({
            'update_feedback': ModelStage(
                callable=lambda self: self.feedback_channel,  
                where_input=lambda input, state: self.where_feedback(state),
                where_state=lambda state: state.feedback,  
            ),
            'net_step': ModelStage(
                callable=lambda self: self.net,
                where_input=lambda input, state: (input, state.feedback.output),
                where_state=lambda state: state.net,                
            ),
            'mechanics_step': ModelStage(
                callable=lambda self: self.mechanics,
                where_input=lambda input, state: state.net.output,
                where_state=lambda state: state.mechanics,
            ),
        })    

    def __call__(
        self, 
        input: PyTree,  
        state: SimpleFeedbackState, 
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState: 

        # Get a different key for each stage of the model.
        keys = jr.split(key, len(self.model_spec))

        # Loop through the model stages, pairing them with their keys.
        for (label, stage), key_stage in zip(self.model_spec.items(), keys):

            # Loop through all intervenors assigned to this model stage.
            for intervenor in self.intervenors.get(label, []):
                state = intervenor(state)

            # Get the component for this stage. Remember, `stage.callable`
            # is a function of the model itself.
            component = stage.callable(self)

            # Get the updated part of the state associated with the stage.
            new_component_state = component(
                stage.where_input(input, state),
                stage.where_state(state),
                key_stage,
            )

            # Modify the full model state
            state = eqx.tree_at(
                stage.where_state,  # Part to modify
                state,  # What is modified (full state)
                new_component_state,  # New part to insert
            )

        return state

Our model is now structured so that it's possible to insert interventions among its stages, without rewriting the whole thing each time!
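
For example, here's a sketch of an intervenor that halves the motor command. The module is hypothetical (Feedbax provides its own intervenor classes), but it has the right shape for the scheme above:

class ScaleNetworkOutput(eqx.Module):
    """Hypothetical intervenor: scale the network's output command."""
    factor: float = 0.5

    def __call__(self, state: SimpleFeedbackState) -> SimpleFeedbackState:
        return eqx.tree_at(
            lambda s: s.net.output,
            state,
            self.factor * state.net.output,
        )

Constructing the model with intervenors={'mechanics_step': [ScaleNetworkOutput()]} halves the command immediately before the mechanics step, without touching the model class.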

This __call__ method is too general to be limited to SimpleFeedback. In fact, the real feedbax.bodies.SimpleFeedback doesn't define __call__ itself, but inherits it from feedbax.AbstractStagedModel.

Every staged model is a subclass of AbstractStagedModel, and only needs to define model_spec (and a couple of other smaller things).

Defining models as a sequence of named state operations has some additional advantages, beyond being able to insert interventions among the stages. For one, it makes it easy to log the details of our model stages as they are executed, which is useful for debugging.
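
For instance, since every stage of every staged model passes through the same loop in __call__, a single debug statement there can trace the whole model. This is a sketch only, not necessarily how Feedbax logs stages:

import logging

logger = logging.getLogger("staged_model")

# Inside the stage loop of __call__, just before the component is called:
#
#     logger.debug(f"Running stage {label!r} of {type(self).__name__}")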

Pretty printing of model stages

Another advantage of staged models is that it's easy to print out a tree of operations, showing the sequence in which they're performed.

Feedbax provides the function pprint_model_spec for this purpose.

import jax
from feedbax import pprint_model_spec
from feedbax.xabdeef import point_mass_nn_simple_reaches

context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))

pprint_model_spec(context.model.step)
update_feedback: MultiModel
nn_step: SimpleStagedNetwork
  hidden: GRUCell
  hidden_nonlinearity: wrapped: identity_func
  readout: wrapped: wrapped
  out_nonlinearity: wrapped: identity_func
mechanics_step: Mechanics
  convert_effector_force: wrapped: PointMass.update_state_given_effector_force
  kinematics_update: DirectForceInput
    clip_skeleton_state: wrapped: DirectForceInput._clip_state
  dynamics_step: wrapped: Mechanics.dynamics_step
  get_effector: wrapped: PointMass.effector

Each line corresponds to a model stage. When the component for that stage is also a staged model, its own stages are printed on the lines that follow, with indentation. For example, "clip_skeleton_state" is a stage of DirectForceInput, which is called as part of the "kinematics_update" stage of Mechanics.

Writing a staged model

The following components are needed to define your own staged model.

  1. A type of state PyTree—that is, a subclass of equinox.Module that defines the state arrays that the model will be able to operate on. The fields of this subclass are either JAX arrays, or state PyTrees associated with the model's components. For example, if the model has a Mechanics component, its state PyTree will have a MechanicsState field.
  2. A final subclass of AbstractStagedModel. This subclass must implement
    • a model_spec property defining, as above, the information needed to call the model stages;
    • an init method that takes a random key, and returns a default model state of the type defined in 1;
    • the field intervenors: dict, where intervenors will be stored.

For example, here's how to define a staged model that contains two neural networks in a loop, where:

  1. the first network receives (without delay) the prior output of the second network;
  2. the output of the first network is passed to a Channel which implements a delay; and
  3. the delayed output of the first network—i.e., output that was added to the channel state during an earlier call to NetworkLoop—is passed to the second network.

from collections import OrderedDict

import equinox as eqx
import jax

from feedbax import AbstractStagedModel, ModelStage
from feedbax.channel import Channel, ChannelState
from feedbax.nn import SimpleStagedNetwork, NetworkState


class NetworkLoopState(eqx.Module):
    net1: NetworkState
    net2: NetworkState
    channel: ChannelState


class NetworkLoop(AbstractStagedModel):
    net1: SimpleStagedNetwork
    net2: SimpleStagedNetwork
    channel: Channel
    intervenors: dict = eqx.field(default_factory=dict)  # default: no intervenors

    @property
    def model_spec(self):
        return OrderedDict({
            'net1_step': ModelStage(
                callable=lambda self: self.net1,
                where_input=lambda input, state: state.net2.output,
                where_state=lambda state: state.net1,
            ),
            'channel': ModelStage(
                callable=lambda self: self.channel,
                where_input=lambda input, state: state.net1.output,
                where_state=lambda state: state.channel,
            ),
            'net2_step': ModelStage(
                callable=lambda self: self.net2,
                where_input=lambda input, state: state.channel.output,
                where_state=lambda state: state.net2,
            ),
        })

    def init(self, *, key):
        keys = jax.random.split(key, 3)
        return NetworkLoopState(
            # Any components that are staged models will also have an
            # init method, which we use to construct their component states.
            net1=self.net1.init(key=keys[0]),
            net2=self.net2.init(key=keys[1]),
            channel=self.channel.init(key=keys[2]),
        )

To construct this model, we have to construct its components. Normally we write a setup function to make this reproducible.

import jax.numpy as jnp
import jax.random as jr


def setup(
    net1_hidden_size,
    net2_hidden_size,
    channel_delay=5,
    channel_noise_std=0.05,
    *,
    key,
):
    key1, key2 = jr.split(key)
    net1 = SimpleStagedNetwork(
        input_size=net2_hidden_size,
        hidden_size=net1_hidden_size,
        key=key1,
    )
    net2 = SimpleStagedNetwork(
        input_size=net1_hidden_size,
        hidden_size=net2_hidden_size,
        key=key2
    )
    channel = Channel(
        channel_delay,
        channel_noise_std,
        # The channel carries net1's output, which has size net1_hidden_size.
        input_proto=jnp.zeros(net1_hidden_size),
    )

    return NetworkLoop(net1=net1, net2=net2, channel=channel)
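
Here's a rough sketch of how the model might then be constructed and stepped; the hidden sizes are arbitrary:

key = jr.PRNGKey(0)
key_setup, key_init, key_step = jr.split(key, 3)

model = setup(net1_hidden_size=8, net2_hidden_size=12, key=key_setup)
state = model.init(key=key_init)

# One pass around the loop. None of the stages select anything out of
# `input`, so we can pass None.
state = model(None, state, key_step)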

Notice that none of the stages of NetworkLoop pass any part of the input to their respective components, so the networks don't directly receive any task information. However, we can still imagine optimizing for certain targets inside the networks during training. And we could imagine that this is just a motif we would use as part of a larger model.

Using stateless modules and simple functions as components

Some components may be able to take one part of the model's state as input, and return another part as output, without having their own associated state. Often these are components whose inputs and outputs are single JAX arrays.

For example, consider the "readout" stage of SimpleStagedNetwork, which operates on NetworkState. The component for this stage is an eqx.nn.Linear layer. Notice that Linear.__call__ only takes a single argument, x: Array. This is the input array to the linear layer, which in our case corresponds to state.hidden—the activity of the hidden layer of our network. The output of the layer is used to update state.output.

Looking back at our generalized definition of __call__ for staged models, each model stage is called like

    # Get the component for this stage, then the updated part of the
    # state associated with it.
    component = stage.callable(self)
    new_component_state = component(
        stage.where_input(input, state),
        stage.where_state(state),
        key_stage,
    )

Not including the key, two arguments are passed to the component, after being selected by where_input and where_state. Clearly this will raise an error if the component is a single-argument module like Linear.

Note

As is generally the case for Feedbax and Equinox modules, Linear also has a key argument—but being a non-stochastic layer, it will simply ignore any key it's passed. If a component doesn't take a key argument at all, the solution is similar to the one we discuss below.

When we define our model_spec, we have to be careful in these cases. Here's a sketch of how we could define this stage:

class StagedNetWithReadout(AbstractStagedModel):
    readout: eqx.nn.Linear
    ...

    @property 
    def model_spec(self):
        return OrderedDict({
            # ... other stages ...

            'readout': ModelStage(
                callable=lambda self: (
                    # Pass the input on to the linear layer; discard the
                    # state and key arguments.
                    lambda input, state, key: self.readout(input)
                ),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.output,
            ),

            # ... other stages ...
        })

We wrap self.readout inside a second lambda, which passes the where_input(...) argument on to Linear, but simply discards the where_state(...) argument and the key. Note that where_state still needs to be defined, since it's also used to determine which part of the model's state to update with the output of the linear layer.

If you find double-lambdas hard to read, you might prefer the function feedbax.wrap_stateless_keyless_callable:

callable=lambda self: wrap_stateless_keyless_callable(self.readout),

There's also a function wrap_stateless_callable for when the key argument should be passed through:

callable=lambda self: wrap_stateless_callable(self.stateless_component),
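
Roughly, both wrappers just adapt a plain callable to the three-argument stage interface. Here's a sketch of the idea; the actual Feedbax implementations may differ in their details:

def wrap_stateless_callable(callable_):
    """Adapt `callable_(input, key=key)` to the `(input, state, key)` interface."""
    def wrapped(input, state, key):
        return callable_(input, key=key)
    return wrapped


def wrap_stateless_keyless_callable(callable_):
    """Adapt `callable_(input)` to the `(input, state, key)` interface."""
    def wrapped(input, state, key):
        return callable_(input)
    return wrapped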