Staged models
feedbax.AbstractStagedModel(AbstractModel[StateT])
Base class for state-dependent models whose stages can be intervened upon.
Info
To define a new staged model, the following complementary components must be implemented:
1. A PyTree of model states, typically a final subclass of equinox.Module. The fields of the PyTree are typically JAX arrays, or else other PyTrees of model states associated with the model's components.
2. A final subclass of AbstractStagedModel. Note that the abstract class is a Generic, and for proper type checking, the type argument of the subclass should be the type of state PyTree defined in (1). This subclass must implement the following:
   - A model_spec property giving a mapping from stage labels to ModelStage instances, each specifying an operation performed on the model state.
   - An init method that takes a random key and returns a default model state.
   - A
For an example, consider (1) SimpleFeedbackState and (2) SimpleFeedback.
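The two components can be sketched without feedbax. In this hypothetical, pure-Python analogue, frozen dataclasses stand in for the final equinox.Module subclasses, `Stage` stands in for ModelStage (its intervenors field is omitted), and `CounterState`, `CounterModel`, and their fields are illustrative names only.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class CounterState:
    # (1) The state PyTree: an immutable record of the model's state.
    position: float
    velocity: float


@dataclass(frozen=True)
class Stage:
    # Hypothetical stand-in for ModelStage (intervenors omitted).
    callable_: Callable[[Any], Callable[..., Any]]  # model -> operation
    where_input: Callable[[Any, Any], Any]  # (input, state) -> operation input
    where_state: Callable[[Any], Any]  # state -> substate the operation updates


@dataclass(frozen=True)
class CounterModel:
    # (2) The model: declares its stages and how to build a default state.
    dt: float = 0.1

    def accelerate(self, force: float, velocity: float, *, key: Any) -> float:
        return velocity + self.dt * force

    def move(self, velocity: float, position: float, *, key: Any) -> float:
        return position + self.dt * velocity

    @property
    def model_spec(self) -> "OrderedDict[str, Stage]":
        # Stages are listed in the order in which they should run.
        return OrderedDict([
            ("accelerate", Stage(
                callable_=lambda self: self.accelerate,
                where_input=lambda input, state: input,
                where_state=lambda state: state.velocity,
            )),
            ("move", Stage(
                callable_=lambda self: self.move,
                where_input=lambda input, state: state.velocity,
                where_state=lambda state: state.position,
            )),
        ])

    def init(self, *, key: Any) -> CounterState:
        # Default state; a model with random initial values would use `key`.
        return CounterState(position=0.0, velocity=0.0)


model = CounterModel()
print(list(model.model_spec))  # ['accelerate', 'move']
print(model.init(key=0))  # CounterState(position=0.0, velocity=0.0)
```

Each stage only declares which operation to run and which parts of the input and state it reads and updates; nothing is computed until the model is called.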
step: Module
property
The model step.
For an AbstractStagedModel, this is trivially the model itself.
state_consistency_update(state: StateT) -> StateT
Inherited from feedbax._model.AbstractModel
__call__(input: ModelInput, state: StateT, key: PRNGKeyArray) -> StateT
Return an updated model state, given input and a prior state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | ModelInput | The input to the model. | required |
state | StateT | The prior state of the model. | required |
key | PRNGKeyArray | A random key, which will be split to provide separate keys for each model stage and intervenor. | required |
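The role of key can be sketched in plain Python. Here random.Random stands in for JAX's PRNG (in JAX the split would typically use jax.random.split), integers replace PRNGKeyArray, intervenors are omitted, and `split`, `call`, and both stages are hypothetical. The sketch shows only that one key yields a separate, reproducible key for every stage, and that the stages run in order.

```python
import random
from collections import OrderedDict
from typing import Callable, Dict, List


def split(key: int, num: int) -> List[int]:
    # Stand-in for a PRNG split: derive `num` new keys deterministically from `key`.
    rng = random.Random(key)
    return [rng.getrandbits(32) for _ in range(num)]


def call(stages: "OrderedDict[str, Callable]", input: float,
         state: Dict[str, float], key: int) -> Dict[str, float]:
    # Give each stage its own subkey, then apply the stages in order.
    for (_, stage), stage_key in zip(stages.items(), split(key, len(stages))):
        state = stage(input, state, key=stage_key)
    return state


def noisy_drive(input, state, *, key):
    # Uses its key to add a small amount of reproducible noise.
    noise = random.Random(key).gauss(0.0, 0.01)
    return {**state, "drive": input + noise}


def integrate(input, state, *, key):
    return {**state, "position": state["position"] + state["drive"]}


stages = OrderedDict([("noisy_drive", noisy_drive), ("integrate", integrate)])
state0 = {"position": 0.0, "drive": 0.0}
out_a = call(stages, 1.0, state0, key=42)
out_b = call(stages, 1.0, state0, key=42)
print(out_a == out_b)  # True: the same key reproduces the same result
```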
init(*, key: PRNGKeyArray) -> StateT
abstractmethod
Return a default state for the model.
model_spec() -> OrderedDict[str, ModelStage[Self, StateT]]
Specify the model's computation in terms of state operations.
Warning
It's necessary to return an OrderedDict because jax.tree_util sorts the keys of a plain dict (even though Python dicts preserve insertion order), which usually puts the stages out of order.
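Why the order matters can be sketched in plain Python. JAX is not imported here: sorted() stands in for the key sorting that, per the warning, jax.tree_util applies to a plain dict, and the stage labels and update rules are hypothetical.

```python
from collections import OrderedDict

# Hypothetical stages, each mapping a state dict to an updated state dict,
# listed in the order in which they must run.
stages = OrderedDict([
    ("update_feedback", lambda s: {**s, "feedback": s["position"]}),
    ("net_step", lambda s: {**s, "command": -s["feedback"]}),
    ("mechanics_step", lambda s: {**s, "position": s["position"] + s["command"]}),
])


def run(state, labelled_stages):
    # Apply the stages one after another, in the order given.
    for _, stage in labelled_stages:
        state = stage(state)
    return state


initial = {"position": 1.0, "feedback": 0.0, "command": 0.0}

# Insertion order, which OrderedDict preserves.
in_order = run(initial, stages.items())
# Alphabetical order, standing in for the key sorting applied to a plain dict.
sorted_order = run(initial, sorted(stages.items()))

print(in_order["position"])  # 0.0: the command is computed from fresh feedback first
print(sorted_order["position"])  # 1.0: the mechanics step ran before the command was computed
```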
feedbax.ModelStage(Module, Generic[ModelT, T])
Specification for a stage in a subclass of AbstractStagedModel.
Each stage of a model is a callable that performs a modification to part of the model state.
Note
To ensure that references to parts of the model instance remain fresh, callable_ takes the instance of AbstractStagedModel (i.e. self) and returns the callable associated with the stage.
Otherwise, references can become stale. For example, if we assigned callable_=self.net for the neural network update in SimpleFeedback, it would continue to refer to the neural network assigned to self.net when the model was constructed, even after the network weights have been updated during training, so the model would not train. Assigning callable_=lambda self: self.net instead looks up the current network each time the stage is run.
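A minimal pure-Python sketch of the difference, assuming (as with equinox modules) that training produces a new, updated model instance rather than mutating the old one. `Net`, `Model`, and `run_stage` are hypothetical; the two lambdas mirror capturing self.net once versus `callable_=lambda self: self.net`.

```python
from dataclasses import dataclass, replace
from typing import Callable


@dataclass(frozen=True)
class Net:
    weight: float

    def __call__(self, x: float) -> float:
        return self.weight * x


@dataclass(frozen=True)
class Model:
    net: Net


def run_stage(callable_: Callable[[Model], Callable[[float], float]],
              model: Model, x: float) -> float:
    # Resolve the stage's operation against the model instance passed in now.
    return callable_(model)(x)


original = Model(net=Net(weight=1.0))

# Like assigning callable_=self.net: the network is captured once, at construction.
captured_net = original.net
stale_callable_ = lambda self: captured_net
# Like callable_=lambda self: self.net: the network is looked up on each call.
fresh_callable_ = lambda self: self.net

# "Training" yields a new, updated model instead of mutating the old one.
trained = replace(original, net=Net(weight=3.0))

print(run_stage(stale_callable_, trained, 2.0))  # 2.0: still uses the initial weight
print(run_stage(fresh_callable_, trained, 2.0))  # 6.0: uses the trained weight
```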
Attributes:
Name | Type | Description |
---|---|---|
callable_ | | A function that takes the model instance and returns the module, method, or function that transforms part of the model state. |
where_input | Callable[[PyTree, T], PyTree] | Selects the parts of the input and state to be passed as input to callable_. |
where_state | Callable[[T], PyTree] | Selects the substate that is passed to callable_ as state, and that is replaced by its return value. |
intervenors | StageIntervenors[T] | Optionally, a sequence of state interventions to be applied at the beginning of this model stage. |
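How where_input and where_state route data through a stage can be sketched in plain Python. This is a hypothetical analogue, not the feedbax implementation: `replace_at` stands in for a PyTree update utility such as equinox.tree_at and handles only selectors that pick a single top-level field, `apply_stage` receives the operation directly rather than through callable_(model), and `State` and `net_step` are illustrative.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class State:
    feedback: float
    hidden: float
    effector: float


def replace_at(state: Any, where: Callable[[Any], Any], value: Any) -> Any:
    # Probe `where` with a unique marker in every top-level field to learn which
    # field it selects, then return a copy of `state` with that field replaced.
    markers = {f.name: object() for f in fields(state)}
    selected = where(replace(state, **markers))
    for name, marker in markers.items():
        if selected is marker:
            return replace(state, **{name: value})
    raise ValueError("`where` must select a single top-level field")


def apply_stage(operation, where_input, where_state, input, state, key=None):
    # Pass the selected input and substate to the operation...
    out = operation(where_input(input, state), where_state(state), key=key)
    # ...then write its output back into the same substate.
    return replace_at(state, where_state, out)


def net_step(inputs: Tuple[float, float], hidden: float, *, key=None) -> float:
    target, feedback = inputs
    return hidden + (target - feedback)


state = State(feedback=0.5, hidden=0.0, effector=0.0)
new_state = apply_stage(
    net_step,
    where_input=lambda input, state: (input, state.feedback),
    where_state=lambda state: state.hidden,
    input=2.0,
    state=state,
)
print(new_state)  # State(feedback=0.5, hidden=1.5, effector=0.0)
```

Only the selected substate changes; the other fields of the state are carried over unchanged.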
Pretty printing of model stages
feedbax.pprint_model_spec(model: AbstractStagedModel, indent: int = 2, newlines: bool = False)
Prints a string representation of the model specification tree.
Shows what is called by model, and by any AbstractStagedModels it calls.
Warning
This assumes that the model spec is a tree. If there are cycles in the model spec, this will recurse until an exception is raised.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractStagedModel | The staged model to format. | required |
indent | int | Number of spaces to indent each nested level of the tree. | 2 |
newlines | bool | Whether to add an extra blank line between each line. | False |
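A recursive printer of the same shape can be sketched in plain Python. This is not feedbax's implementation: the spec is modelled as nested OrderedDicts mapping stage labels to the names of what they call, `format_spec` and the stage names are hypothetical, and the output format is invented for illustration. As the warning notes, a spec that contains itself would make this recurse until Python raises RecursionError.

```python
from collections import OrderedDict


def format_spec(spec: OrderedDict, indent: int = 2, newlines: bool = False,
                level: int = 0) -> str:
    # Each entry maps a stage label to either the name of what it calls (str)
    # or the spec of a nested staged model (another OrderedDict).
    prefix = " " * (indent * level)
    lines = []
    for label, target in spec.items():
        if isinstance(target, OrderedDict):
            # Nested staged model: print its label, then recurse one level deeper.
            lines.append(f"{prefix}{label}:")
            lines.append(format_spec(target, indent, newlines, level + 1))
        else:
            lines.append(f"{prefix}{label}: {target}")
    # With newlines=True, every printed line is separated by an extra blank line.
    return ("\n\n" if newlines else "\n").join(lines)


spec = OrderedDict([
    ("update_feedback", "feedback_channel"),
    ("net_step", "network"),
    ("mechanics_step", OrderedDict([
        ("plant_step", "plant_dynamics"),
        ("integrate", "euler_step"),
    ])),
])

print(format_spec(spec))
# update_feedback: feedback_channel
# net_step: network
# mechanics_step:
#   plant_step: plant_dynamics
#   integrate: euler_step
```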