Interventions
`feedbax.intervene.CurlFieldParams(AbstractIntervenorInput)`
Parameters for a curl force field.
Attributes:
Name | Type | Description |
---|---|---|
scale | float | Scaling factor on the intervenor output. |
active | bool | Whether the force field is active. |
amplitude | float | The amplitude of the force field. Negative is clockwise, positive is counterclockwise. |
`feedbax.intervene.CurlField(AbstractIntervenor['MechanicsState', CurlFieldParams])`
Apply a curl force field to a mechanical effector.
Attributes:
Name | Type | Description |
---|---|---|
params | CurlFieldParams | Default curl field parameters. |
in_where | Callable[[MechanicsState], Float[Array, '... ndim=2']] | Returns the substate corresponding to the effector's velocity. |
out_where | Callable[[MechanicsState], Float[Array, '... ndim=2']] | Returns the substate corresponding to the force on the effector. |
operation | Callable[[ArrayLike, ArrayLike], ArrayLike] | How to combine the effector force due to the curl field with the existing force on the effector. Default is addition. |
`with_params(**kwargs) -> Self`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`transform(params: CurlFieldParams, substate_in: Float[Array, 'ndim=2'], *, key: PRNGKeyArray) -> Float[Array, 'ndim=2']`
Transform velocity into curl force.
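The sign convention for `amplitude` can be shown with a small NumPy sketch. It assumes the curl force is the effector velocity rotated by 90 degrees and multiplied by `amplitude` and `scale`, so a positive amplitude turns the force counterclockwise relative to the direction of motion. `curl_force` is a hypothetical helper, not part of feedbax, and the library's actual `transform` may differ in detail.

```python
import numpy as np


def curl_force(velocity, amplitude, scale=1.0):
    """Hypothetical helper: map a 2D velocity (or a batch of them) to a curl force.

    Assumes force = scale * amplitude * R90(velocity), where R90 is a
    90-degree counterclockwise rotation.
    """
    velocity = np.asarray(velocity, dtype=float)
    # [..., ::-1] turns (vx, vy) into (vy, vx); multiplying by [-1, 1]
    # gives (-vy, vx), the counterclockwise rotation of (vx, vy).
    rotated = np.array([-1.0, 1.0]) * velocity[..., ::-1]
    return scale * amplitude * rotated


# Moving along +x with a positive amplitude yields a force along +y.
force = curl_force([1.0, 0.0], amplitude=2.0)
```

Because the force is always perpendicular to the velocity, it deflects the effector's path without directly speeding it up or slowing it down.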
`feedbax.intervene.AddNoise(AbstractIntervenor[StateT, AddNoiseParams])`
Add noise to a part of the state.
Attributes:
Name | Type | Description |
---|---|---|
params | AddNoiseParams | Default intervention parameters. |
out_where | Callable[[StateT], PyTree[Array, T]] | Returns the substate to which noise is added. |
operation | Callable[[ArrayLike, ArrayLike], ArrayLike] | How to combine the noise with the substate. Default is addition. |
`with_params(**kwargs) -> Self`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`transform(params: AddNoiseParams, substate_in: PyTree[Array, T], *, key: PRNGKeyArray) -> PyTree[Array, T]`
Return a PyTree of scaled noise arrays with the same structure and shapes as `substate_in`.
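The following sketch shows what "same structure and shapes" means for this transform. It uses NumPy and a small recursive tree map over dicts, lists, and tuples in place of JAX PyTrees. It assumes zero-mean Gaussian noise multiplied by a single scale factor, and `tree_map` and `scaled_noise_like` are hypothetical names. The noise distribution and key handling in feedbax may differ.

```python
import numpy as np


def tree_map(fn, tree, *rest):
    """Apply `fn` leafwise to one or more nested dict/list/tuple structures."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, *leaves) for leaves in zip(tree, *rest))
    return fn(tree, *rest)


def scaled_noise_like(substate_in, scale, rng):
    """Hypothetical: Gaussian noise matching the structure and leaf shapes of `substate_in`."""
    return tree_map(lambda x: scale * rng.standard_normal(np.shape(x)), substate_in)


rng = np.random.default_rng(0)
substate = {"pos": np.zeros(2), "vel": [np.zeros((3, 2)), np.zeros(4)]}
noise = scaled_noise_like(substate, scale=0.1, rng=rng)
# With the default additive `operation`, the result is combined leafwise.
noisy = tree_map(lambda x, n: x + n, substate, noise)
```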
`feedbax.intervene.NetworkClamp(AbstractIntervenor['NetworkState', NetworkIntervenorParams])`
Clamps the activities of some of a network's units to given values.
Attributes:
Name | Type | Description |
---|---|---|
params | NetworkIntervenorParams | Default intervention parameters. |
out_where | Callable[[NetworkState], PyTree[Array, T]] | Returns the substate of arrays giving the activity of the units whose activities may be clamped. |
operation | Callable[[ArrayLike, ArrayLike], ArrayLike] | How to combine the original and clamped unit activities. Default is to replace the original activities with the clamped ones. |
`with_params(**kwargs) -> Self`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
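The default "replace" operation can be sketched in NumPy. For illustration only, this assumes the clamp values come as an array the size of the layer, with NaN marking units that keep their original activity. `clamp_units` is hypothetical, and the way `NetworkIntervenorParams` actually specifies units may differ.

```python
import numpy as np


def clamp_units(activity, unit_spec):
    """Hypothetical: overwrite activities wherever `unit_spec` is not NaN."""
    activity = np.asarray(activity, dtype=float)
    unit_spec = np.asarray(unit_spec, dtype=float)
    # NaN entries keep the original activity; other entries replace it.
    return np.where(np.isnan(unit_spec), activity, unit_spec)


hidden = np.array([0.3, -0.2, 0.8, 0.1])
clamped = clamp_units(hidden, [np.nan, 0.0, np.nan, 1.0])
```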
`feedbax.intervene.NetworkConstantInput(AbstractIntervenor['NetworkState', NetworkIntervenorParams])`
Adds a constant input to some network units.
Attributes:
Name | Type | Description |
---|---|---|
params | NetworkIntervenorParams | Default intervention parameters. |
out_where | Callable[[NetworkState], PyTree[Array, T]] | Returns the substate of arrays giving the activity of the units to which a constant input may be added. |
operation | Callable[[ArrayLike, ArrayLike], ArrayLike] | How to combine the original and altered unit activities. Default is addition. |
`with_params(**kwargs) -> Self`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
`__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT`
Inherited from feedbax.intervene.intervene.AbstractIntervenor
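A constant input shifts the selected units' activities instead of overwriting them. The sketch below assumes, for illustration only, that the input is given as an array the size of the layer, with NaN marking units that receive no input. `add_constant_input` is hypothetical, and feedbax's own unit specification may differ.

```python
import numpy as np


def add_constant_input(activity, unit_spec):
    """Hypothetical: add `unit_spec` to the activities, treating NaN as zero input."""
    activity = np.asarray(activity, dtype=float)
    constant = np.nan_to_num(np.asarray(unit_spec, dtype=float), nan=0.0)
    # Default `operation` for this intervenor: addition.
    return activity + constant


hidden = np.array([0.3, -0.2, 0.8, 0.1])
shifted = add_constant_input(hidden, [np.nan, 0.5, np.nan, -0.1])
```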
Adding interventions to tasks and models
`feedbax.intervene.schedule_intervenor(tasks: PyTree[AbstractTask], models: PyTree[AbstractModel[StateT]], where: Callable[[AbstractModel[StateT]], Any], intervenor: AbstractIntervenor | Type[AbstractIntervenor], stage_name: Optional[str] = None, default_active: bool = False, label: Optional[str] = None, validation_same_schedule: bool = True, intervenor_params: Optional[AbstractIntervenorInput] = None, intervenor_params_validation: Optional[AbstractIntervenorInput] = None) -> Tuple[PyTree[AbstractTask], PyTree[AbstractModel[StateT]]]`
Adds an intervention to a model and a task.
Accepts either an intervenor instance or an intervenor class. If an intervenor instance is passed without `intervenor_params`, the instance's `params` attribute is used as `intervenor_params`. This can be combined with the intervenor's `with_params` constructor to define the schedule. For example:
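```python
import jax.random as jr
from feedbax.intervene import CurlField, schedule_intervenor

# `tasks` and `models` are assumed to be defined already.
tasks, models = schedule_intervenor(
    tasks,
    models,
    lambda model: model.step.mechanics,
    CurlField.with_params(
        # Draw the amplitude from a standard normal distribution on each trial.
        amplitude=lambda trial_spec, *, key: jr.normal(key, (1,)),
        active=True,
    ),
    # ...
)
```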
If an intervenor class and `intervenor_params` are passed, an instance is constructed from the two.
If an intervenor instance and `intervenor_params` are both passed, the instance's `params` are replaced with `intervenor_params` before the intervenor is added to the model.
If an intervenor class is passed without `intervenor_params`, an error is raised, since there is not enough information to schedule the intervention.
Passing a value for `intervenor_params_validation` allows separate control over the intervention schedule for the task's validation set.
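The cases above can be summarized as a small dispatch function. This is a sketch of the logic the text describes, not feedbax's implementation. `resolve_intervenor`, the toy `ToyParams` and `ToyIntervenor` dataclasses, and the use of `ValueError` are all assumptions made for illustration.

```python
from dataclasses import dataclass, field, replace
from typing import Optional


@dataclass(frozen=True)
class ToyParams:
    """Hypothetical stand-in for an AbstractIntervenorInput subclass."""
    scale: float = 1.0
    active: bool = False
    amplitude: float = 0.0


@dataclass(frozen=True)
class ToyIntervenor:
    """Hypothetical stand-in for an AbstractIntervenor subclass."""
    params: ToyParams = field(default_factory=ToyParams)


def resolve_intervenor(intervenor, intervenor_params: Optional[ToyParams] = None):
    """Return (intervenor instance, params used to build the schedule)."""
    if isinstance(intervenor, type):
        if intervenor_params is None:
            # A bare class carries no parameters, so there is nothing to schedule.
            raise ValueError("intervenor_params is required when passing a class")
        return intervenor(params=intervenor_params), intervenor_params
    if intervenor_params is None:
        # Instance without explicit params: reuse the instance's own params.
        return intervenor, intervenor.params
    # Instance with explicit params: they replace the instance's params.
    return replace(intervenor, params=intervenor_params), intervenor_params
```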
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tasks | PyTree[AbstractTask] | The task(s) in whose trials the intervention will be scheduled. | required |
models | PyTree[AbstractModel[StateT]] | The model(s) to which the intervention will be added. | required |
where | Callable[[AbstractModel[StateT]], Any] | Takes | required |
intervenor | AbstractIntervenor \| Type[AbstractIntervenor] | The intervenor (or intervenor class) to schedule. | required |
stage_name | Optional[str] | The name of the stage in | None |
validation_same_schedule | bool | Whether the interventions should be scheduled in the same way for the validation set as for the training set. | True |
intervenor_params | Optional[AbstractIntervenorInput] | The parameters of the intervenor, which may be constants or callables that are used by | None |
intervenor_params_validation | Optional[AbstractIntervenorInput] | Same as | None |
default_active | bool | If the intervenor added to the model should have | False |
`feedbax.intervene.add_intervenors(model: AbstractStagedModel[StateT], where: Callable[[AbstractStagedModel[StateT]], Any], intervenors: Union[Sequence[AbstractIntervenor], Mapping[StageNameStr, Union[Sequence[AbstractIntervenor], Mapping[IntervenorLabelStr, AbstractIntervenor]]]], stage_name: Optional[StageNameStr] = None, keep_existing: bool = True) -> AbstractStagedModel[StateT]`
Return an updated model with added intervenors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractStagedModel[StateT] | The model to which the intervenors will be added. | required |
where | Callable[[AbstractStagedModel[StateT]], Any] | Takes | required |
intervenors | Union[Sequence[AbstractIntervenor], Mapping[StageNameStr, Union[Sequence[AbstractIntervenor], Mapping[IntervenorLabelStr, AbstractIntervenor]]]] | The intervenors to add. May be 1) a sequence of intervenors to execute, by default before the first model stage, or 2) a dict/mapping from stage names to a) the sequence of intervenors to execute at the end of the respective model stage, or b) another dict/mapping from custom intervenor labels to intervenors to execute at the end of that stage. | required |
stage_name | Optional[StageNameStr] | If | None |
keep_existing | bool | Whether to keep the existing intervenors belonging directly to the instance of | True |
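To make the accepted forms of `intervenors` concrete, the sketch below normalizes each of them into a mapping from stage name to a mapping from label to intervenor. The automatic labeling scheme (class name plus a counter), the `BEFORE_FIRST_STAGE` placeholder key, the stage names, and the function names are all hypothetical. feedbax's own labels and placement rules may differ.

```python
from collections.abc import Mapping
from typing import Optional

BEFORE_FIRST_STAGE = "<before first stage>"  # hypothetical placeholder key


def label_sequence(intervenors) -> dict:
    """Hypothetical labeling: class name plus a per-class counter."""
    labels, counts = {}, {}
    for intervenor in intervenors:
        name = type(intervenor).__name__
        counts[name] = counts.get(name, 0) + 1
        labels[f"{name}_{counts[name]}"] = intervenor
    return labels


def normalize_intervenors(intervenors, stage_name: Optional[str] = None) -> dict:
    """Return {stage name: {label: intervenor}} for any accepted input form."""
    if not isinstance(intervenors, Mapping):
        # Form 1: a plain sequence, attached to `stage_name` if one is given.
        stage = stage_name if stage_name is not None else BEFORE_FIRST_STAGE
        return {stage: label_sequence(intervenors)}
    normalized = {}
    for stage, stage_intervenors in intervenors.items():
        if isinstance(stage_intervenors, Mapping):
            # Form 2b: custom labels are kept as given.
            normalized[stage] = dict(stage_intervenors)
        else:
            # Form 2a: a sequence for this stage, labeled automatically.
            normalized[stage] = label_sequence(stage_intervenors)
    return normalized


class Noise:  # hypothetical intervenor types, for illustration
    pass


class Clamp:
    pass


by_stage = normalize_intervenors({"hidden": [Clamp()], "mechanics": {"my_curl": Noise()}})
```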
`feedbax.intervene.add_fixed_intervenor(model: AbstractStagedModel[StateT], where: Callable[[AbstractModel[StateT]], Any], intervenor: AbstractIntervenor, stage_name: Optional[StageNameStr] = None, label: Optional[IntervenorLabelStr] = None, **kwargs: Any) -> AbstractStagedModel[StateT]`
Return an updated model with an added, fixed intervenor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractStagedModel[StateT] | The model to which the intervenor will be added. | required |
where | Callable[[AbstractModel[StateT]], Any] | Takes | required |
intervenor | AbstractIntervenor | The intervenor to add. | required |
stage_name | Optional[StageNameStr] | The stage named in | None |
label | Optional[IntervenorLabelStr] | Custom key for the intervenor, which determines how it will be accessed as part of the model PyTree. Note that labels for fixed intervenors are prepended with | None |
kwargs | Any | Additional keyword arguments to | {} |
Base classes
`feedbax.intervene.AbstractIntervenorInput(Module)`
Base class for PyTrees of intervention parameters.
Attributes:
Name | Type | Description |
---|---|---|
active | AbstractVar[bool] | Whether the intervention is active. |
scale | AbstractVar[float] | Factor by which the intervenor output is scaled. |
`feedbax.intervene.InputT` (module attribute)
InputT = TypeVar('InputT', bound=AbstractIntervenorInput)
`feedbax.intervene.AbstractIntervenor(Module, Generic[StateT, InputT])`
Base class for modules that intervene on a model's state.
Attributes:
Name | Type | Description |
---|---|---|
params | AbstractVar[InputT] | Default intervention parameters. |
in_where | AbstractVar[Callable[[StateT], PyTree[ArrayLike, T]]] | Takes an instance of the model state and returns the substate corresponding to the intervenor's input. |
out_where | AbstractVar[Callable[[StateT], PyTree[ArrayLike, S]]] | Takes an instance of the model state and returns the substate corresponding to the intervenor's output. In many cases, |
operation | AbstractVar[Callable[[ArrayLike, ArrayLike], ArrayLike]] | Which operation to use to combine the original and altered |
`__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT`
Return a state PyTree modified by the intervention.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input | InputT | PyTree of intervention parameters. If any leaves are | required |
state | StateT | The model state to be intervened upon. | required |
key | PRNGKeyArray | A key to provide randomness for the intervention. | required |
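The attributes and the `__call__` parameters fit together roughly as shown below. This sketch uses plain Python, with a dict for the state and a dict for the parameters. It assumes the output of `transform` is multiplied by `scale` and combined through `operation` only when `active` is true. It omits the `key` argument. `out_where` is simplified to a dict key, because writing through a getter function needs a helper such as Equinox's `tree_at`. `intervene` is a hypothetical name that illustrates the data flow; feedbax's implementation operates on JAX/Equinox PyTrees.

```python
from typing import Any, Callable, Dict

State = Dict[str, Any]


def intervene(
    state: State,
    params: Dict[str, Any],
    in_where: Callable[[State], Any],
    out_key: str,
    transform: Callable[[Dict[str, Any], Any], Any],
    operation: Callable[[Any, Any], Any] = lambda x, y: x + y,
) -> State:
    """Hypothetical sketch of one intervenor call on a dict-based state."""
    if not params["active"]:
        # An inactive intervention leaves the state as it was.
        return state
    substate_in = in_where(state)
    # The transform output is multiplied by `scale` (an assumption here).
    altered = params["scale"] * transform(params, substate_in)
    new_state = dict(state)  # copy, so the caller's state is not mutated
    new_state[out_key] = operation(state[out_key], altered)
    return new_state


# A toy scalar intervention: force += scale * amplitude * velocity.
state = {"velocity": 2.0, "force": 1.0}
params = {"active": True, "scale": 0.5, "amplitude": 3.0}
new_state = intervene(
    state, params, lambda s: s["velocity"], "force", lambda p, v: p["amplitude"] * v
)
```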
`transform(params: InputT, substate_in: PyTree[ArrayLike, T], *, key: PRNGKeyArray) -> PyTree[ArrayLike, S]` (abstract method)
Transforms the input substate to produce an altered output substate.
`with_params(**kwargs) -> Self` (class method)
Constructor that accepts the field names of `InputT` as keywords.
This is a convenience: it lets us instantiate an intervenor class without importing the parameter class associated with it.
Example:
```python
from feedbax.intervene import CurlField
CurlField.with_params(amplitude=10.0)
```
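The sketch below shows one way a constructor like this can work, using dataclasses. The intervenor class records its parameter class, and the class method builds the parameters from keyword arguments. `params_cls`, `ToyCurlParams`, and `ToyCurlField` are hypothetical, and the default values are arbitrary. feedbax may determine the parameter class differently.

```python
from dataclasses import dataclass, field
from typing import ClassVar, Type


@dataclass(frozen=True)
class ToyCurlParams:
    """Hypothetical parameter class with the fields listed for CurlFieldParams."""
    scale: float = 1.0
    active: bool = True
    amplitude: float = 0.0


@dataclass(frozen=True)
class ToyCurlField:
    """Hypothetical intervenor that knows its own parameter class."""
    params_cls: ClassVar[Type[ToyCurlParams]] = ToyCurlParams
    params: ToyCurlParams = field(default_factory=ToyCurlParams)

    @classmethod
    def with_params(cls, **kwargs):
        # Callers pass parameter field names directly, without importing
        # the parameter class themselves.
        return cls(params=cls.params_cls(**kwargs))


curl = ToyCurlField.with_params(amplitude=10.0)
```

A misspelled keyword fails immediately with a `TypeError` from the parameter class's constructor, so mistakes surface when the intervenor is built rather than when it runs.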