# Models with stages
In this example, we'll check out
- the structure of Feedbax staged models;
- how their structure allows them to be modified with arbitrary interventions;
- how to write your own staged models.
This is an advanced example that assumes familiarity with type annotations and dataclasses/Equinox modules.
All Feedbax models are defined as Equinox modules. A module gathers together model parameters with the model computation (function) they parametrize. A module may be composed of other modules, in which case it forms a nested structure of model components and their respective parameters. The examples of `SimpleFeedback` modules we define below are all of this kind.
Once a model object is constructed, we can call it like a function. We're allowed to do that because we define its `__call__` method.
## The structure of a staged model
A lot can happen in `__call__`.

Consider the following simplified definition of `feedbax.bodies.SimpleFeedback`, which is a model of a single time step in which a neural network, after receiving some sensory feedback, sends a command to a mechanical model:
```python
from collections.abc import Callable

import equinox as eqx
import jax.random as jr
from jaxtyping import PRNGKeyArray, PyTree

from feedbax.mechanics import Mechanics, MechanicsState
from feedbax.nn import SimpleStagedNetwork, NetworkState
from feedbax.channel import Channel, ChannelState


class SimpleFeedbackState(eqx.Module):
    mechanics: MechanicsState
    net: NetworkState
    feedback: ChannelState


class SimpleFeedback(eqx.Module):
    """Model of one step around a feedback loop between a neural network
    and a mechanical model.
    """
    net: SimpleStagedNetwork
    mechanics: Mechanics
    feedback_channel: Channel
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton

    def __call__(
        self,
        input: PyTree,
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:
        key1, key2 = jr.split(key)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output),
            state.net,
            key2,
        )

        mechanics_state = self.mechanics(
            network_state.output,
            state.mechanics,
        )

        return SimpleFeedbackState(
            mechanics=mechanics_state,
            net=network_state,
            feedback=feedback_state,
        )

    # ...
    # We've omitted a bunch of other stuff from this definition!
```
First, notice that `SimpleFeedback` is an Equinox `Module` subclass. It's not obvious from this code alone, but `Mechanics` and `Channel` are also `Module` subclasses, with their own parameters and components.

Observe the following about `__call__`:

- It takes a `SimpleFeedbackState` object, and returns a new one constructed from all the updated substates. This is where "the state of the model" is actually stored, not in `SimpleFeedback` itself.
- It also takes an argument `input`. This argument is a PyTree that includes any inputs to the model that aren't part of its state.
- It contains several steps. Each step involves calling one of the components of the model, such as `self.feedback_channel` or `self.net`. Each component is passed some part of the model state, and returns an updated version of that part. Like the parent model, each component also takes additional information as input.

The components of the model that are called during `__call__` are:

`self.feedback_channel`:

- is a `Channel` object, which is a type of Equinox `Module`;
- takes as input `self.where_feedback(state)`, which is the part of `state` we want to store in the state of the feedback channel, to be retrieved on some later time step, depending on the delay;
- takes `state.feedback` as its prior state. Looking at `SimpleFeedbackState`: `state.feedback` is a `ChannelState` object. This makes sense, as `ChannelState` is the type of state PyTree that `Channel` depends on;
- returns an updated `ChannelState`, which we assign to `feedback_state`.

Note: The default for `self.where_feedback` is `lambda state: state.mechanics.plant.skeleton`, which means that our sensory feedback consists of the full state of the skeleton, typically the positions and velocities of some joints.

`self.net`:

- is a `SimpleStagedNetwork` object, which is a type of Equinox `Module`;
- takes as input `(input, feedback_state.output)`. Here, `input` is the entire argument passed to `__call__` itself. Since `SimpleFeedback` is a top-level model, its `input` will consist of the trial-by-trial information needed to complete the task. The neural network is the proper recipient of this information, in this case;
- takes `state.net` as its prior state. This is a `NetworkState` object;
- returns an updated `NetworkState`, which we assign to `network_state`.

Note: This is the only step in the model that receives the `input` that was passed to `SimpleFeedback` itself. This is because the input to the model is typically information the network needs to complete the task, say, the position of the goal it should reach to. The input to all of the other model steps is some other part of the model state.

`self.mechanics`:

- is a `Mechanics` object, which is a type of Equinox `Module`;
- takes as input `network_state.output`, where `network_state` contains the updated `NetworkState` returned by `self.net`: `network_state.output` is the command we want to send to the mechanical model;
- takes `state.mechanics` as its prior state. This is a `MechanicsState` object;
- returns an updated `MechanicsState`, which we assign to `mechanics_state`.
## Trying to intervene
What if we want to interfere with the command the neural network generates, after we call `self.net` but before we call `self.mechanics`? We could write a new module with an extra component that operates on `NetworkState`, and call it at the right moment:
```python
class SimpleFeedbackPerturbNetworkOutput(eqx.Module):
    net: eqx.Module
    mechanics: Mechanics
    feedback_channel: Channel
    intervention: eqx.Module
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton

    def __call__(
        self,
        input: PyTree,
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:
        key1, key2 = jr.split(key)

        feedback_state = self.feedback_channel(
            self.where_feedback(state),
            state.feedback,
            key1,
        )

        network_state = self.net(
            (input, feedback_state.output),
            state.net,
            key2,
        )

        # modifies `network_state.output` somehow
        network_state = self.intervention(network_state)

        mechanics_state = self.mechanics(
            network_state.output,
            state.mechanics,
        )

        return SimpleFeedbackState(
            mechanics=mechanics_state,
            net=network_state,
            feedback=feedback_state,
        )
```
It would be pretty inconvenient to have to do this every time we want to intervene a little. Once we have a model, it'd be nice to experiment on it quickly.
Also, if we have a different model that's similar enough to `SimpleFeedback` that it might make sense to use the same kind of `NetworkState` intervention that we just used, we wouldn't want to have to manually rewrite that model too.
Thankfully we can do something about this.
## A more general way to intervene
Start by noticing that each step in the `__call__` method of our original `SimpleFeedback`:

- is defined as a modification of some part of the model state: each operation we perform returns some part of `SimpleFeedbackState`;
- calls a model component in a consistent way: no matter if we're calling `self.feedback_channel`, `self.net`, or `self.mechanics`, our call always looks like `self.something(input_to_something, state_associated_with_something, key)`.
Each component will also need a `key`, but we won't need to specify how that works for each component individually. We'll just have to be sure to split up the `key` passed to `__call__`, so that each component gets a different key.
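Key splitting is just `jax.random.split` with a count; a minimal sketch:

```python
import jax.random as jr

key = jr.PRNGKey(0)

# One independent subkey per model stage, so no stage
# reuses another stage's randomness.
keys = jr.split(key, 3)
```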
That means we can define each step in `__call__` with three pieces of information:

- Which model component to call, say `self.net`;
- How to select the input to that model component, from the full `input` and `state` passed to `SimpleFeedback`;
- How to select the state associated with (and modified by) that model component, out of the full `state` passed to `SimpleFeedback`.
OK, let's try to do that. We'll define an object called `ModelStage` which holds the three pieces of information required to define each model stage. Then we'll define a `model_spec` that defines all the stages of our model.
```python
from collections import OrderedDict


class ModelStage(eqx.Module):
    func: Callable
    where_input: Callable
    where_state: Callable


# Note: the lambdas below refer to `self`. This spec will shortly become a
# property of the model class, where `self` refers to the model instance.
model_spec = OrderedDict({
    'update_feedback': ModelStage(
        # See explanation below for why we define this as a lambda!
        func=lambda self: self.feedback_channel,
        where_input=lambda input, state: self.where_feedback(state),
        where_state=lambda state: state.feedback,
    ),
    'net_step': ModelStage(
        func=lambda self: self.net,
        where_input=lambda input, state: (input, state.feedback.output),
        where_state=lambda state: state.net,
    ),
    'mechanics_step': ModelStage(
        func=lambda self: self.mechanics,
        where_input=lambda input, state: state.net.output,
        where_state=lambda state: state.mechanics,
    ),
})
```
Note: The model stages need to be executed in order. Even though a `dict` does maintain the order of its entries in current Python versions, this is not the case when a `dict` is transformed using PyTree operations from `jax.tree_util`. Therefore we use `OrderedDict`.
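To see the difference, compare how `jax.tree_util` flattens the two: plain `dict` keys are sorted, while an `OrderedDict` keeps its insertion order.

```python
from collections import OrderedDict

import jax

d = {'b': 1, 'a': 2}

# Plain dicts are flattened with sorted keys: 'a' comes first.
leaves_dict = jax.tree_util.tree_leaves(d)

# OrderedDicts preserve insertion order: 'b' comes first.
leaves_od = jax.tree_util.tree_leaves(OrderedDict(d))
```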
Note: Each of the fields in `ModelStage` is typed as `Callable`, which means it can be a function, a method, or any object with `__call__` defined. In this case, we define them inline as `lambda` functions.
For `where_input` and `where_state`, this is similar to what we've seen in earlier examples. For example, `where_input` will take the `input` and `state` passed to `__call__`, and return the parts to be passed to the current stage's `func`.
Why do we define `func` as `lambda self: self.something` rather than just `self.something`? It's to make sure that references to the component "stay fresh". If that doesn't make sense to you, don't worry about it at this point. Just remember that if you write your own staged models, you will need to write your `model_spec` this way.
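The staleness problem the lambda avoids can be illustrated with a plain dataclass, standing in for an Equinox module (Equinox model surgery similarly replaces attributes by building an updated copy of the model):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Model:
    gain: float

model = Model(gain=1.0)

stale = model.gain              # captures the value at definition time
fresh = lambda self: self.gain  # re-resolved against whatever model we pass

# "Surgery": build an updated copy of the model.
model = replace(model, gain=2.0)

# The direct reference is stale; the lambda sees the updated model.
```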
In order to insert interventions at arbitrary points, here's what we'll do:

- Include `model_spec` as an attribute of `SimpleFeedback` itself.
- Define `__call__` so that it calls each of the components in `model_spec`, passing them their respective subsets of the input and state, and using their return value to update the model state.

  Importantly, the way we define `__call__` will no longer allow our model stages to assign, or refer, to intermediate variables like `feedback_state`. This is why in the `model_spec` we just defined, the input to `self.net` includes `state.feedback.output`, where in our original definition of `SimpleFeedback` we had passed `feedback_state.output`.

  In our new `__call__`, we'll update the full model state immediately after each stage, rather than assigning to intermediate variables and then finally constructing a new `SimpleFeedbackState`. Every stage's inputs are only selected out of the full model state, not out of intermediate copies of parts of the state.

- Give `SimpleFeedback` a new attribute `intervenors`, where we can insert additional components that intervene on the model's state, given the name of the model stage they should be applied before. For example, if this attribute is set to `{'mechanics_step': [some_intervention, some_other_intervention]}`, then `some_intervention` and `some_other_intervention` would be called one after the other, immediately before `self.mechanics` is called.
```python
class SimpleFeedback(eqx.Module):
    net: eqx.Module
    mechanics: Mechanics
    feedback_channel: Channel
    intervenors: dict
    where_feedback: Callable = lambda state: state.mechanics.plant.skeleton

    @property
    def model_spec(self):
        return OrderedDict({
            'update_feedback': ModelStage(
                func=lambda self: self.feedback_channel,
                where_input=lambda input, state: self.where_feedback(state),
                where_state=lambda state: state.feedback,
            ),
            'net_step': ModelStage(
                func=lambda self: self.net,
                where_input=lambda input, state: (input, state.feedback.output),
                where_state=lambda state: state.net,
            ),
            'mechanics_step': ModelStage(
                func=lambda self: self.mechanics,
                where_input=lambda input, state: state.net.output,
                where_state=lambda state: state.mechanics,
            ),
        })

    def __call__(
        self,
        input: PyTree,
        state: SimpleFeedbackState,
        key: PRNGKeyArray,
    ) -> SimpleFeedbackState:
        # Get a different key for each stage of the model.
        keys = jr.split(key, len(self.model_spec))

        # Loop through the model stages, pairing them with their keys.
        for (label, stage), key_stage in zip(self.model_spec.items(), keys):

            # Apply any intervenors assigned to this model stage.
            for intervenor in self.intervenors.get(label, []):
                state = intervenor(state)

            # Get the updated part of the state associated with the stage.
            # `stage.func` is a lambda taking the model itself, so we call it
            # with `self` to retrieve the actual component.
            new_component_state = stage.func(self)(
                stage.where_input(input, state),
                stage.where_state(state),
                key_stage,
            )

            # Modify the full model state
            state = eqx.tree_at(
                stage.where_state,    # Part to modify
                state,                # What is modified (the full state)
                new_component_state,  # New part to insert
            )

        return state
```
Our model is now structured so that it's possible to insert interventions among its stages, without rewriting the whole thing each time!
This `__call__` method is too general to be limited to `SimpleFeedback`. In fact, the real `feedbax.bodies.SimpleFeedback` doesn't define `__call__` itself, but inherits it from `feedbax.AbstractStagedModel`.

Every staged model is a subclass of `AbstractStagedModel`, and only needs to define `model_spec` (and a couple of other smaller things).
Defining models as a sequence of named state operations has some additional advantages, beyond being able to insert interventions among the stages. For one, it makes it easy to log the details of our model stages as they are executed, which is useful for debugging.
## Pretty printing of model stages
Another advantage of staged models is that it's easy to print out a tree of operations, showing the sequence in which they're performed. Feedbax provides the function `pprint_model_spec` for this purpose.
```python
import jax

from feedbax import pprint_model_spec
from feedbax.xabdeef import point_mass_nn_simple_reaches

context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))

pprint_model_spec(context.model.step)
```
Each line corresponds to a model stage. When the component for that stage is also a staged model, its own stages are printed on the lines that follow, with indentation. For example, `"clip_skeleton_state"` is a stage of `DirectForceInput`, which is called as part of the `"statics_step"` stage of `Mechanics`.
## Writing a staged model
The following components are needed to define your own staged model.

1. A type of state PyTree: a subclass of `equinox.Module` that defines the state arrays that the model will be able to operate on. The fields of this subclass are either JAX arrays, or state PyTrees associated with the model's components. For example, if the model has a `Mechanics` component, its state PyTree will have a `MechanicsState` field.
2. A final subclass of `AbstractStagedModel`. This subclass must implement:
    - a `model_spec` property defining, as above, the information needed to call the model stages;
    - an `init` method that takes a random key, and returns a default model state of the type defined in 1;
    - the field `intervenors: dict`, where intervenors will be stored.
For example, here's how to define a staged model that contains two neural networks in a loop, where:

- the first network receives (without delay) the prior output of the second network;
- the output of the first network is passed to a `Channel` which implements a delay; and
- the delayed output of the first network (i.e. added to the channel state during an earlier call to `NetworkLoop`) is passed to the second network.
```python
from collections import OrderedDict

import equinox as eqx
import jax

from feedbax import AbstractStagedModel, ModelStage
from feedbax.channel import Channel, ChannelState
from feedbax.nn import SimpleStagedNetwork, NetworkState


class NetworkLoopState(eqx.Module):
    net1: NetworkState
    net2: NetworkState
    channel: ChannelState


class NetworkLoop(AbstractStagedModel):
    net1: SimpleStagedNetwork
    net2: SimpleStagedNetwork
    channel: Channel
    intervenors: dict

    @property
    def model_spec(self):
        return OrderedDict({
            'net1_step': ModelStage(
                func=lambda self: self.net1,
                where_input=lambda input, state: state.net2.output,
                where_state=lambda state: state.net1,
            ),
            'channel': ModelStage(
                func=lambda self: self.channel,
                where_input=lambda input, state: state.net1.output,
                where_state=lambda state: state.channel,
            ),
            'net2_step': ModelStage(
                func=lambda self: self.net2,
                where_input=lambda input, state: state.channel.output,
                where_state=lambda state: state.net2,
            ),
        })

    def init(self, *, key):
        keys = jax.random.split(key, 3)
        return NetworkLoopState(
            # Any components that are staged models will also have an
            # init method, which we use to construct their component states.
            net1=self.net1.init(key=keys[0]),
            net2=self.net2.init(key=keys[1]),
            channel=self.channel.init(key=keys[2]),
        )
```
To construct this model, we have to construct its components. Normally we write a setup function to make this reproducible.
```python
import jax.numpy as jnp
import jax.random as jr


def setup(
    net1_hidden_size,
    net2_hidden_size,
    channel_delay=5,
    channel_noise_std=0.05,
    *,
    key,
):
    key1, key2 = jr.split(key)

    net1 = SimpleStagedNetwork(
        input_size=net2_hidden_size,
        hidden_size=net1_hidden_size,
        key=key1,
    )
    net2 = SimpleStagedNetwork(
        input_size=net1_hidden_size,
        hidden_size=net2_hidden_size,
        key=key2,
    )
    # The channel carries net1's output, so its input prototype
    # has net1's hidden size.
    channel = Channel(
        channel_delay,
        channel_noise_std,
        input_proto=jnp.zeros(net1_hidden_size),
    )

    # No interventions by default.
    return NetworkLoop(net1=net1, net2=net2, channel=channel, intervenors={})
```
Notice that none of the stages of `NetworkLoop` pass any part of the `input` to their respective components, so the networks couldn't be directly receiving any task information. However, we can still imagine optimizing for certain targets inside the network during training. And we could imagine that this is just a motif we would use as part of a larger model.
## Using stateless modules and simple functions as components
Some components may be able to take one part of the model's state as input, and return another part as output, without having their own associated state. Often these are components whose inputs and outputs are single JAX arrays.
For example, consider the `"readout"` stage of `SimpleStagedNetwork`, which operates on `NetworkState`. The component for this stage is an `eqx.nn.Linear` layer. Notice that `Linear.__call__` only takes a single positional argument, `x: Array`. This is the input array to the linear layer, which in our case corresponds to `state.hidden`, the activity of the hidden layer of our network. The output of the layer is used to update `state.output`.
Looking back at our generalized definition of `__call__` for staged models, each model stage is called like:

```python
# Get the updated part of the state associated with the stage
new_component_state = stage.func(self)(
    stage.where_input(input, state),
    stage.where_state(state),
    key_stage,
)
```

Not including the key, two arguments are passed, after being selected by `where_input` and `where_state`. Clearly this will raise an error if `stage.func` refers to a single-argument module like `Linear`.

Note: As is generally the case for Feedbax and Equinox modules, `Linear` also has an argument `key`, but being a non-stochastic layer, it will simply ignore it if passed. If the component doesn't take a `key` argument at all, the solution is similar to what we discuss below.
When we define our `model_spec`, we have to be careful in these cases. Here's a sketch of how we could define this stage:
```python
class StagedNetWithReadout(AbstractStagedModel):
    readout: eqx.nn.Linear

    ...

    @property
    def model_spec(self):
        return OrderedDict({
            ...,
            'readout': ModelStage(
                # Wrap the layer so it fits the (input, state, key) interface.
                func=lambda self: (
                    lambda input, state, key: self.readout(input)
                ),
                where_input=lambda input, state: state.hidden,
                where_state=lambda state: state.output,
            ),
            ...,
        })
```
We wrap `self.readout` inside a second lambda, which passes the `where_input(...)` argument on to `Linear`, but simply discards the `where_state(...)` argument. Note that `where_state` still needs to be defined, since it's also used to determine which part of the model's state to update with the output of the linear layer.
If you find double-lambdas hard to read, you might prefer the function `feedbax.wrap_stateless_keyless_callable`:

```python
func=lambda self: wrap_stateless_keyless_callable(self.readout),
```
There's also a function `wrap_stateless_callable` for when the `key` argument should be passed through:

```python
func=lambda self: wrap_stateless_callable(self.stateless_component),
```
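These wrappers are small adapters. A plausible sketch of what `wrap_stateless_keyless_callable` does (the actual Feedbax implementation may differ in details):

```python
def wrap_stateless_keyless_callable(func):
    # Adapt a single-argument callable to the (input, state, key) stage
    # interface, discarding the state and key arguments.
    def wrapped(input, state, key=None):
        return func(input)
    return wrapped

double = wrap_stateless_keyless_callable(lambda x: 2 * x)
```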