
Staged models

feedbax.AbstractStagedModel (AbstractModel[StateT])

Base class for state-dependent models whose stages can be intervened upon.

Info

To define a new staged model, the following complementary components must be implemented:

  1. A PyTree of model states, typically a final subclass of equinox.Module. Its fields are usually JAX arrays, or else other PyTrees of model states associated with the model's components.
  2. A final subclass of AbstractStagedModel. Note that the abstract class is Generic: for proper type checking, the subclass's type argument should be the state PyTree type defined in (1).

    This subclass must implement the following:

    1. A model_spec property giving a mapping from stage labels to ModelStage instances, each specifying an operation performed on the model state.
    2. An init method that takes a random key and returns a default model state.

For an example, see SimpleFeedbackState (component 1) and SimpleFeedback (component 2).
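As a rough, dependency-free sketch of the two components listed above, the following plain-Python code stands in for them under several assumptions: `CounterState`, `Stage`, `CounterModel`, and `run_stages` are hypothetical names, not feedbax classes; frozen dataclasses play the role of final `equinox.Module` subclasses; each stage names the state field it updates instead of using a `where_state` selector; and random keys and intervenors are omitted. `run_stages` only approximates the ordered, stage-by-stage update that `__call__` is described as performing.

```python
from collections import OrderedDict
from dataclasses import dataclass, replace
from typing import Any, Callable


@dataclass(frozen=True)
class CounterState:
    """Hypothetical state PyTree (component 1): immutable, one field per substate."""

    count: float
    total: float


@dataclass(frozen=True)
class Stage:
    """Hypothetical stand-in for ModelStage; `field` replaces a where_state selector."""

    callable_: Callable[[Any], Callable[[Any, Any], Any]]  # model -> stage function
    where_input: Callable[[Any, Any], Any]  # (input, state) -> stage input
    field: str  # name of the state field this stage updates


@dataclass(frozen=True)
class CounterModel:
    """Hypothetical staged model (component 2): implements model_spec and init."""

    step_size: float

    @property
    def model_spec(self) -> "OrderedDict[str, Stage]":
        # Insertion order is the order in which the stages run.
        return OrderedDict(
            increment=Stage(
                callable_=lambda model: model.increment,
                where_input=lambda input, state: input,
                field="count",
            ),
            accumulate=Stage(
                callable_=lambda model: model.accumulate,
                where_input=lambda input, state: state.count,
                field="total",
            ),
        )

    def increment(self, input, count):
        return count + self.step_size * input

    def accumulate(self, count, total):
        return total + count

    def init(self) -> CounterState:
        return CounterState(count=0.0, total=0.0)


def run_stages(model, input, state):
    """Run each stage in order, replacing only the substate it owns."""
    for stage in model.model_spec.values():
        func = stage.callable_(model)  # resolve against the current model
        stage_input = stage.where_input(input, state)
        new_substate = func(stage_input, getattr(state, stage.field))
        state = replace(state, **{stage.field: new_substate})
    return state


model = CounterModel(step_size=2.0)
state = model.init()
for _ in range(2):
    state = run_stages(model, 1.0, state)
print(state)  # CounterState(count=4.0, total=6.0)
```

After two calls, `total` is 6.0 rather than 4.0 because `accumulate` reads the `count` already updated by the preceding stage, which is why stage order matters.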

step: Module property

The model step.

For an AbstractStagedModel, this is trivially the model itself.

state_consistency_update (state: StateT) -> StateT

__call__ (input: ModelInput, state: StateT, key: PRNGKeyArray) -> StateT

Return an updated model state, given input and a prior state.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| input | ModelInput | The input to the model. | required |
| state | StateT | The prior state of the model. | required |
| key | PRNGKeyArray | A random key which will be split to provide separate keys for each model stage and intervenor. | required |
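The `key` parameter is described as being split into separate keys for each stage and intervenor. A minimal sketch of one way to do that with `jax.random.split` follows; the helper `split_stage_keys`, its two-level splitting scheme, and the stage labels are assumptions for illustration, not feedbax's actual implementation.

```python
from collections import OrderedDict

import jax


def split_stage_keys(key, intervenor_counts):
    """Hypothetical helper: one key per stage callable and one per intervenor.

    `intervenor_counts` maps each stage label (in execution order) to the
    number of intervenors attached to that stage.
    """
    stage_keys = jax.random.split(key, len(intervenor_counts))
    keys = OrderedDict()
    for (label, n_intervenors), stage_key in zip(intervenor_counts.items(), stage_keys):
        # Split each stage's key again so the stage and its intervenors
        # never share randomness.
        subkeys = jax.random.split(stage_key, 1 + n_intervenors)
        keys[label] = (subkeys[0], list(subkeys[1:]))
    return keys


keys = split_stage_keys(
    jax.random.PRNGKey(0),
    OrderedDict(feedback=0, network=2, mechanics=1),
)
```

Splitting rather than passing the same key along matters in JAX, because reusing a key reproduces the same random numbers.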
init (*, key: PRNGKeyArray) -> StateT (abstractmethod)

Return a default state for the model.

model_spec () -> OrderedDict[str, ModelStage[Self, StateT]]

Specify the model's computation in terms of state operations.

Warning

The return value must be an OrderedDict rather than a plain dict, because jax.tree_util sorts dict keys when flattening, which usually puts the stages out of order. OrderedDict keys keep their insertion order.
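The ordering problem described in this warning is easy to reproduce. In the sketch below (the stage labels are arbitrary), flattening a plain `dict` with `jax.tree_util.tree_leaves` visits keys in sorted order, while an `OrderedDict` keeps insertion order.

```python
from collections import OrderedDict

import jax

# Stage labels written in the order the stages should run.
labels = ["sensory", "network", "mechanics"]

plain = {label: i for i, label in enumerate(labels)}
ordered = OrderedDict((label, i) for i, label in enumerate(labels))

# A plain dict is flattened in sorted key order: mechanics, network, sensory.
print(jax.tree_util.tree_leaves(plain))    # [2, 1, 0]
# An OrderedDict is flattened in insertion order.
print(jax.tree_util.tree_leaves(ordered))  # [0, 1, 2]
```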

feedbax.ModelStage (Module, Generic[ModelT, T])

Specification for a stage in a subclass of AbstractStagedModel.

Each stage of a model is a callable that performs a modification to part of the model state.

Note

To ensure that references to parts of the model instance remain fresh, callable_ takes the instance of AbstractStagedModel (i.e. self) and returns the callable associated with the stage.

Otherwise, references can become stale. For example, if we assign callable_=self.net for the neural network update in SimpleFeedback, the stage will keep referring to the network that was assigned to self.net when the model was constructed, even after the network weights have been updated during training. As a result, the model will not train.
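The stale-reference problem can be reproduced in plain Python. In this hypothetical sketch, a frozen dataclass stands in for an immutable `equinox.Module`, and `dataclasses.replace` stands in for a training update that returns a new model instead of mutating the old one; `Model`, `stale`, and `fresh` are illustrative names.

```python
from dataclasses import dataclass, replace
from typing import Callable


@dataclass(frozen=True)
class Model:
    """Hypothetical immutable model whose 'network' is just a function."""

    net: Callable[[float], float]


def fresh(model: Model) -> Callable[[float], float]:
    # Like callable_: look the network up on whichever model is passed in.
    return model.net


old = Model(net=lambda x: x + 1.0)
stale = old.net  # Captured once, like callable_=self.net at construction.

# "Training" produces a new model instance with an updated network.
new = replace(old, net=lambda x: x + 100.0)

print(stale(0.0))       # 1.0: still the network from construction
print(fresh(new)(0.0))  # 100.0: the updated network
```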

Attributes:

| Name | Type | Description |
|------|------|-------------|
| callable_ | | The module, method, or function that transforms part of the model state. |
| where_input | Callable[[PyTree, T], PyTree] | Selects the parts of the input and state to be passed as input to callable_. |
| where_state | Callable[[T], PyTree] | Selects the substate that callable_ receives as its state and returns in updated form. |
| intervenors | StageIntervenors[T] | Optionally, a sequence of state interventions to be applied at the beginning of this model stage. |

Pretty printing of model stages

feedbax.pprint_model_spec (model: AbstractStagedModel, indent: int = 2, newlines: bool = False)

Prints a string representation of the model specification tree.

Shows what is called by model, and by any AbstractStagedModels it calls.

Warning

This assumes that the model spec is a tree. If there are cycles in the model spec, this will recurse until an exception is raised.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | AbstractStagedModel | The staged model to format. | required |
| indent | int | Number of spaces to indent each nested level of the tree. | 2 |
| newlines | bool | Whether to add an extra blank line between each line. | False |
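The recursive, indented layout described for pprint_model_spec can be approximated as follows. This is a hypothetical sketch over nested `OrderedDict`s (labels map either to a callable's name or to a nested spec standing in for a sub-model), not the library's implementation; `format_spec`, `_spec_lines`, and the example labels are illustrative, while `indent` and `newlines` mirror the parameters in the table.

```python
from collections import OrderedDict


def _spec_lines(spec, indent, level):
    """Yield one line per stage, recursing into nested specs."""
    pad = " " * (indent * level)
    for label, target in spec.items():
        if isinstance(target, dict):
            # A stage whose callable is itself a staged model: recurse one level deeper.
            yield f"{pad}{label}:"
            yield from _spec_lines(target, indent, level + 1)
        else:
            yield f"{pad}{label}: {target}"


def format_spec(spec, indent=2, newlines=False):
    """Hypothetical formatter mirroring the indent/newlines parameters."""
    separator = "\n\n" if newlines else "\n"  # blank line between lines if requested
    return separator.join(_spec_lines(spec, indent, 0))


spec = OrderedDict(
    feedback="channel",
    network="net",
    plant=OrderedDict(dynamics="integrate", clip="clip_state"),
)
print(format_spec(spec))
# feedback: channel
# network: net
# plant:
#   dynamics: integrate
#   clip: clip_state
```

As the warning for pprint_model_spec notes for cyclic specs, a spec that contains itself would make this sketch recurse until Python raises a `RecursionError`.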