
Interventions

feedbax.intervene.CurlFieldParams (AbstractIntervenorInput)

Parameters for a curl force field.

Attributes:

scale (float): Scaling factor on the intervenor output.

active (bool): Whether the force field is active.

amplitude (float): The amplitude of the force field. Negative values give a clockwise field; positive values give a counterclockwise field.

feedbax.intervene.CurlField (AbstractIntervenor['MechanicsState', CurlFieldParams])

Apply a curl force field to a mechanical effector.

Attributes:

params (CurlFieldParams): Default curl field parameters.

in_where (Callable[[MechanicsState], Float[Array, '... ndim=2']]): Returns the substate corresponding to the effector's velocity.

out_where (Callable[[MechanicsState], Float[Array, '... ndim=2']]): Returns the substate corresponding to the force on the effector.

operation (Callable[[ArrayLike, ArrayLike], ArrayLike]): How to combine the effector force due to the curl field with the existing force on the effector. Default is addition.

with_params(**kwargs) -> Self

__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT

transform(params: CurlFieldParams, substate_in: Float[Array, 'ndim=2'], *, key: PRNGKeyArray) -> Float[Array, 'ndim=2']

Transform velocity into curl force.
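To make the sign convention for amplitude concrete, here is a minimal pure-Python sketch of a curl-field transform on a single 2D velocity. It assumes the common formulation in which the force is the velocity rotated by 90 degrees and multiplied by amplitude and scale; feedbax itself operates on JAX arrays, and where exactly scale is applied is an assumption here. The names curl_force and add_forces are hypothetical.

```python
from typing import Tuple

Vec2 = Tuple[float, float]


def curl_force(velocity: Vec2, amplitude: float, scale: float = 1.0) -> Vec2:
    """Hypothetical sketch: rotate velocity by +90 degrees, times amplitude * scale.

    Positive amplitude gives a counterclockwise field; negative, clockwise.
    """
    vx, vy = velocity
    # A +90 degree (counterclockwise) rotation maps (vx, vy) to (-vy, vx).
    return (-amplitude * scale * vy, amplitude * scale * vx)


def add_forces(existing: Vec2, extra: Vec2) -> Vec2:
    """Default-style operation: add the curl force to the existing force."""
    return (existing[0] + extra[0], existing[1] + extra[1])


# Rightward motion (+x) in a counterclockwise field (amplitude > 0) is pushed toward +y.
force = add_forces((0.0, 0.0), curl_force((1.0, 0.0), amplitude=2.0))
```

Because the curl force is always perpendicular to the velocity, it pushes the effector sideways without directly speeding it up or slowing it down.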

feedbax.intervene.AddNoise (AbstractIntervenor[StateT, AddNoiseParams])

Add noise to a part of the state.

Attributes:

params (AddNoiseParams): Default intervention parameters.

out_where (Callable[[StateT], PyTree[Array, T]]): Returns the substate to which noise is added.

operation (Callable[[ArrayLike, ArrayLike], ArrayLike]): How to combine the noise with the substate. Default is addition.

with_params(**kwargs) -> Self

__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT

transform(params: AddNoiseParams, substate_in: PyTree[Array, T], *, key: PRNGKeyArray) -> PyTree[Array, T]

Return a PyTree of scaled noise arrays with the same structure/shapes as substate_in.
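The following rough sketch shows what "noise with the same structure as substate_in" means, using nested dicts, lists, and tuples of floats in place of a PyTree of arrays, and a seeded random.Random in place of a JAX PRNG key. The helpers noise_like and add_tree are hypothetical, and standard-normal noise multiplied by scale is an assumption about the noise distribution.

```python
import random
from typing import Any


def noise_like(tree: Any, scale: float, rng: random.Random) -> Any:
    """Hypothetical sketch: scaled Gaussian noise mirroring the structure of `tree`.

    Leaves are floats; containers are dicts, lists, or tuples.
    """
    if isinstance(tree, dict):
        return {k: noise_like(v, scale, rng) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(noise_like(v, scale, rng) for v in tree)
    # Leaf: draw one standard-normal sample and scale it.
    return scale * rng.gauss(0.0, 1.0)


def add_tree(x: Any, y: Any) -> Any:
    """Combine two identically structured trees leaf by leaf with addition."""
    if isinstance(x, dict):
        return {k: add_tree(x[k], y[k]) for k in x}
    if isinstance(x, (list, tuple)):
        return type(x)(add_tree(a, b) for a, b in zip(x, y))
    return x + y


rng = random.Random(0)  # stand-in for a JAX PRNG key
state = {"pos": [0.0, 0.0], "vel": (1.0, -1.0)}
noisy_state = add_tree(state, noise_like(state, scale=0.1, rng=rng))
```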

feedbax.intervene.NetworkClamp (AbstractIntervenor['NetworkState', NetworkIntervenorParams])

Clamps the activities of some of a network's units to given values.

Attributes:

params (NetworkIntervenorParams): Default intervention parameters.

out_where (Callable[[NetworkState], PyTree[Array, T]]): Returns the substate of arrays giving the activity of the units whose activities may be clamped.

operation (Callable[[ArrayLike, ArrayLike], ArrayLike]): How to combine the original and clamped unit activities. Default is to replace the original activities with the clamped ones.
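As an illustration of clamping with a replacing operation, here is a hypothetical pure-Python sketch. It assumes a clamp specification with one entry per unit, where NaN marks units that keep their current activity; this encoding and the names clamp_units and replace_op are assumptions, not the library's documented parameter format.

```python
import math
from typing import List


def clamp_units(activity: List[float], clamp_spec: List[float]) -> List[float]:
    """Hypothetical sketch: clamped value per unit; NaN keeps the unit's activity."""
    if len(activity) != len(clamp_spec):
        raise ValueError("activity and clamp_spec must have the same length")
    return [a if math.isnan(c) else c for a, c in zip(activity, clamp_spec)]


def replace_op(original: List[float], altered: List[float]) -> List[float]:
    """Clamp-style operation, like lambda x, y: y: the altered values win."""
    return altered


activity = [0.3, -0.7, 1.2]
clamp_spec = [math.nan, 0.0, math.nan]  # clamp only the second unit, to 0.0
new_activity = replace_op(activity, clamp_units(activity, clamp_spec))  # [0.3, 0.0, 1.2]
```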

with_params(**kwargs) -> Self

__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT

feedbax.intervene.NetworkConstantInput (AbstractIntervenor['NetworkState', NetworkIntervenorParams])

Adds a constant input to some network units.

Attributes:

params (NetworkIntervenorParams): Default intervention parameters.

out_where (Callable[[NetworkState], PyTree[Array, T]]): Returns the substate of arrays giving the activity of the units to which a constant input may be added.

operation (Callable[[ArrayLike, ArrayLike], ArrayLike]): How to combine the original and altered unit activities. Default is addition.

with_params(**kwargs) -> Self

__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT

Adding interventions to tasks and models

feedbax.intervene.schedule_intervenor(
    tasks: PyTree[AbstractTask],
    models: PyTree[AbstractModel[StateT]],
    where: Callable[[AbstractModel[StateT]], Any],
    intervenor: AbstractIntervenor | Type[AbstractIntervenor],
    stage_name: Optional[str] = None,
    default_active: bool = False,
    label: Optional[str] = None,
    validation_same_schedule: bool = True,
    intervenor_params: Optional[AbstractIntervenorInput] = None,
    intervenor_params_validation: Optional[AbstractIntervenorInput] = None,
) -> Tuple[PyTree[AbstractTask], PyTree[AbstractModel[StateT]]]

Adds an intervention to a model and a task.

Accepts either an intervenor instance or an intervenor class. If an intervenor instance is passed without intervenor_params, the instance's params attribute is used as intervenor_params. This can be combined with the intervenor's with_params constructor to define the schedule. For example:

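```python
import jax.random as jr
from feedbax.intervene import CurlField, schedule_intervenor

tasks, models = schedule_intervenor(
    tasks,
    models,
    lambda model: model.step.mechanics,
    CurlField.with_params(
        # Sample a new curl amplitude on each trial.
        amplitude=lambda trial_spec, *, key: jr.normal(key, (1,)),
        active=True,
    ),
    # ... (further arguments)
)
```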

If an intervenor class is passed together with intervenor_params, an instance is constructed from the two.

If an intervenor instance is passed together with intervenor_params, the instance's params are replaced with intervenor_params before the intervenor is added to the model.

If an intervenor class is passed without intervenor_params, an error is raised, since there is not enough information to schedule the intervention.

Passing a value for intervenor_params_validation allows separate control over the intervention schedule for the task's validation set.
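The four cases above can be summarized as a small dispatch function. This is a sketch using hypothetical frozen dataclasses (Params, Intervenor) in place of feedbax's modules; resolve_intervenor is not part of the library, and ValueError is an assumed error type.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Optional, Type, Union


@dataclass(frozen=True)
class Params:
    """Hypothetical stand-in for an AbstractIntervenorInput subclass."""
    amplitude: float = 0.0
    active: bool = False
    scale: float = 1.0


@dataclass(frozen=True)
class Intervenor:
    """Hypothetical stand-in for an intervenor that carries default params."""
    params: Params = field(default_factory=Params)


def resolve_intervenor(
    intervenor: Union[Intervenor, Type[Intervenor]],
    intervenor_params: Optional[Params] = None,
) -> Intervenor:
    """Return the intervenor instance to schedule, following the four cases."""
    if isinstance(intervenor, type):
        if intervenor_params is None:
            # Class without params: not enough information to schedule anything.
            raise ValueError("intervenor_params is required when passing a class")
        # Class with params: construct an instance from the two.
        return intervenor(params=intervenor_params)
    if intervenor_params is None:
        # Instance without params: use the instance's own params.
        return intervenor
    # Instance with params: swap in the given params (the original is unchanged).
    return dataclasses.replace(intervenor, params=intervenor_params)
```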

Parameters:

tasks (PyTree[AbstractTask], required): The task(s) in whose trials the intervention will be scheduled.

models (PyTree[AbstractModel[StateT]], required): The model(s) to which the intervention will be added.

where (Callable[[AbstractModel[StateT]], Any], required): Takes the model and returns the instance of AbstractStagedModel within it (which may be the model itself) to which the intervenor will be added.

intervenor (AbstractIntervenor | Type[AbstractIntervenor], required): The intervenor (or intervenor class) to schedule.

stage_name (Optional[str], default None): The name of the stage in where(model).model_spec at the end of which the intervenor will be executed. If None, the intervenor executes before the first stage of the model.

validation_same_schedule (bool, default True): Whether the intervention should be scheduled the same way for the validation set as for the training set.

intervenor_params (Optional[AbstractIntervenorInput], default None): The parameters for the intervenor. Each may be a constant, or a callable that the task uses to construct the intervention's parameters on each trial.

intervenor_params_validation (Optional[AbstractIntervenorInput], default None): Same as intervenor_params, but for the task's validation set. Overrides validation_same_schedule.

default_active (bool, default False): Whether the intervenor added to the model should have active=True by default, so that the intervention is turned on even if the intervenor does not explicitly receive values for its parameters.
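To illustrate how a task might turn intervenor_params into concrete per-trial values, here is a hedged sketch in which each parameter is either a constant or a callable with the (trial_spec, *, key) signature used in the example above. A seeded random.Random stands in for a JAX PRNG key (which would normally be split for each trial), and params_for_trial is a hypothetical helper.

```python
import random
from typing import Any, Dict


def params_for_trial(
    param_spec: Dict[str, Any], trial_spec: Any, *, key: random.Random
) -> Dict[str, Any]:
    """Resolve one trial's parameters: call the callables, pass constants through."""
    return {
        name: value(trial_spec, key=key) if callable(value) else value
        for name, value in param_spec.items()
    }


spec = {
    # Sample a new amplitude on each trial, mirroring the schedule_intervenor example.
    "amplitude": lambda trial_spec, *, key: key.gauss(0.0, 1.0),
    "active": True,
}
rng = random.Random(0)  # stand-in for a JAX PRNG key
trials = [params_for_trial(spec, trial_spec=None, key=rng) for _ in range(3)]
```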

feedbax.intervene.add_intervenors(
    model: AbstractStagedModel[StateT],
    where: Callable[[AbstractStagedModel[StateT]], Any],
    intervenors: Union[Sequence[AbstractIntervenor], Mapping[StageNameStr, Union[Sequence[AbstractIntervenor], Mapping[IntervenorLabelStr, AbstractIntervenor]]]],
    stage_name: Optional[StageNameStr] = None,
    keep_existing: bool = True,
) -> AbstractStagedModel[StateT]

Return an updated model with added intervenors.

Parameters:

model (AbstractStagedModel[StateT], required): The model to which the intervenors will be added.

where (Callable[[AbstractStagedModel[StateT]], Any], required): Takes the model and returns the instance of AbstractStagedModel within it (which may be the model itself) to which the intervenors will be added.

intervenors (Union[Sequence[AbstractIntervenor], Mapping[StageNameStr, Union[Sequence[AbstractIntervenor], Mapping[IntervenorLabelStr, AbstractIntervenor]]]], required): The intervenors to add. May be either 1) a sequence of intervenors to execute, by default before the first model stage; or 2) a dict/mapping from stage names to either a) the sequence of intervenors to execute at the end of the respective model stage, or b) another dict/mapping from custom intervenor labels to the intervenors to execute at the end of that stage.

stage_name (Optional[StageNameStr], default None): If intervenors is supplied as a simple sequence of intervenors (case 1), execute them at the end of this model stage. By default, they are executed before the first model stage.

keep_existing (bool, default True): Whether to keep the existing intervenors belonging directly to the instance of AbstractStagedModel to which the new intervenors are added. If True, the new intervenors are appended to the existing ones; if False, the old intervenors are replaced.
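A sketch of how the accepted forms of intervenors could be normalized into a single stage-to-label-to-intervenor mapping, and how keep_existing might merge it with what is already there. The sentinel BEFORE_FIRST_STAGE, the class-name labelling scheme, and both helper functions are hypothetical; the library's actual keys, labels, and replacement granularity may differ.

```python
from typing import Any, Dict, Mapping, Optional, Sequence, Union

# Hypothetical key for intervenors that run before the first model stage.
BEFORE_FIRST_STAGE = "<before first stage>"

Schedule = Dict[str, Dict[str, Any]]


def _label_sequence(items: Sequence[Any]) -> Dict[str, Any]:
    """Hypothetical labelling scheme: class name plus position in the sequence."""
    return {f"{type(x).__name__}_{i}": x for i, x in enumerate(items)}


def normalize_intervenors(
    intervenors: Union[Sequence[Any], Mapping[str, Any]],
    stage_name: Optional[str] = None,
) -> Schedule:
    """Map each accepted input form to {stage: {label: intervenor}}."""
    if isinstance(intervenors, Mapping):
        # Case 2: stage names map to (a) sequences or (b) label -> intervenor mappings.
        return {
            stage: dict(v) if isinstance(v, Mapping) else _label_sequence(v)
            for stage, v in intervenors.items()
        }
    # Case 1: a plain sequence runs at the end of stage_name, or before the first stage.
    stage = BEFORE_FIRST_STAGE if stage_name is None else stage_name
    return {stage: _label_sequence(intervenors)}


def merge_schedules(existing: Schedule, new: Schedule, keep_existing: bool = True) -> Schedule:
    """Append new intervenors after existing ones per stage, or replace them all."""
    if not keep_existing:
        return {stage: dict(d) for stage, d in new.items()}
    merged = {stage: dict(d) for stage, d in existing.items()}
    for stage, d in new.items():
        # New labels are appended at the end; a repeated label overwrites the old entry.
        merged.setdefault(stage, {}).update(d)
    return merged
```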

feedbax.intervene.add_fixed_intervenor(
    model: AbstractStagedModel[StateT],
    where: Callable[[AbstractModel[StateT]], Any],
    intervenor: AbstractIntervenor,
    stage_name: Optional[StageNameStr] = None,
    label: Optional[IntervenorLabelStr] = None,
    **kwargs: Any,
) -> AbstractStagedModel[StateT]

Return an updated model with an added, fixed intervenor.

Parameters:

model (AbstractStagedModel[StateT], required): The model to which the intervenor will be added.

where (Callable[[AbstractModel[StateT]], Any], required): Takes the model and returns the instance of AbstractStagedModel within it (which may be the model itself) to which the intervenor will be added.

intervenor (AbstractIntervenor, required): The intervenor to add.

stage_name (Optional[StageNameStr], default None): The stage named in model.model_spec to which the intervenor will be added. The intervenor executes at the end of the stage. If None, the intervenor executes before the first model stage.

label (Optional[IntervenorLabelStr], default None): Custom key for the intervenor, which determines how it is accessed as part of the model PyTree. Labels for fixed intervenors are prefixed with "FIXED_" if they do not already start with it.

kwargs (Any, default {}): Additional keyword arguments passed to add_intervenors.
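The labelling rule for fixed intervenors amounts to a conditional prefix, sketched below. The helper fixed_label and its default argument (standing in for whatever label is used when none is given) are hypothetical.

```python
from typing import Optional

FIXED_PREFIX = "FIXED_"


def fixed_label(label: Optional[str], default: str) -> str:
    """Ensure a fixed intervenor's label starts with FIXED_, without doubling it."""
    base = default if label is None else label
    return base if base.startswith(FIXED_PREFIX) else FIXED_PREFIX + base
```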

Base classes

feedbax.intervene.AbstractIntervenorInput (Module)

Base class for PyTrees of intervention parameters.

Attributes:

active (AbstractVar[bool]): Whether the intervention is active.

scale (AbstractVar[float]): Factor by which the intervenor output is scaled.

feedbax.intervene.InputT (module attribute)

InputT = TypeVar('InputT', bound=AbstractIntervenorInput)

feedbax.intervene.AbstractIntervenor (Module, Generic[StateT, InputT])

Base class for modules that intervene on a model's state.

Attributes:

params (AbstractVar[InputT]): Default intervention parameters.

in_where (AbstractVar[Callable[[StateT], PyTree[ArrayLike, T]]]): Takes an instance of the model state, and returns the substate corresponding to the intervenor's input.

out_where (AbstractVar[Callable[[StateT], PyTree[ArrayLike, S]]]): Takes an instance of the model state, and returns the substate corresponding to the intervenor's output. In many cases, out_where will be the same as in_where.

operation (AbstractVar[Callable[[ArrayLike, ArrayLike], ArrayLike]]): Which operation to use to combine the original and altered out_where substates. For example, an intervenor that clamps a state variable to a particular value should use an operation like lambda x, y: y to replace the original with the altered state. On the other hand, an additive intervenor would use the equivalent of lambda x, y: x + y.
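The roles of in_where, transform, operation, and out_where can be sketched in pure Python as follows. Field names stand in for the in_where/out_where callables, a frozen dataclass stands in for the state PyTree, and apply_intervention is a hypothetical helper, not feedbax's __call__ (which also receives intervention parameters and a key).

```python
import dataclasses
from dataclasses import dataclass
from typing import Callable, Tuple

Vec2 = Tuple[float, float]


@dataclass(frozen=True)
class EffectorState:
    """Hypothetical minimal state: effector velocity and force."""
    vel: Vec2
    force: Vec2


def apply_intervention(
    state: EffectorState,
    in_field: str,
    out_field: str,
    transform: Callable[[Vec2], Vec2],
    operation: Callable[[Vec2, Vec2], Vec2],
) -> EffectorState:
    """Sketch of the flow: read in_where, transform, combine with out_where, write back."""
    altered = transform(getattr(state, in_field))
    original = getattr(state, out_field)
    # Return a new state; the input state is left unchanged.
    return dataclasses.replace(state, **{out_field: operation(original, altered)})


def add(x: Vec2, y: Vec2) -> Vec2:
    # Additive operation; for Python tuples, x + y would concatenate, so add elementwise.
    return (x[0] + y[0], x[1] + y[1])


def replace_with_altered(x: Vec2, y: Vec2) -> Vec2:
    # Clamp-style operation, equivalent to lambda x, y: y.
    return y
```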

__call__(input: InputT, state: StateT, *, key: PRNGKeyArray) -> StateT

Return a state PyTree modified by the intervention.

Parameters:

input (InputT, required): PyTree of intervention parameters. If any leaves are None, they will be replaced by the corresponding leaves of self.params.

state (StateT, required): The model state to be intervened upon.

key (PRNGKeyArray, required): A key to provide randomness for the intervention.
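The rule that None leaves of input fall back to self.params can be sketched for a flat parameter container as below. The dataclass CurlParams and the helper fill_from_defaults are hypothetical; real parameter PyTrees may be nested, which this one-level sketch does not handle.

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional, TypeVar

P = TypeVar("P")


@dataclass(frozen=True)
class CurlParams:
    """Hypothetical flat parameter container; None means 'use the default'."""
    amplitude: Optional[float] = None
    active: Optional[bool] = None
    scale: Optional[float] = None


def fill_from_defaults(given: P, defaults: P) -> P:
    """Replace each None field of `given` with the matching field of `defaults`."""
    updates = {
        f.name: getattr(defaults, f.name)
        for f in dataclasses.fields(given)
        if getattr(given, f.name) is None
    }
    return dataclasses.replace(given, **updates)


defaults = CurlParams(amplitude=0.0, active=False, scale=1.0)
trial_params = fill_from_defaults(CurlParams(amplitude=3.0, active=True), defaults)
```

Note that only None is treated as missing: an explicit False or 0.0 in the given parameters is kept.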
transform(params: InputT, substate_in: PyTree[ArrayLike, T], *, key: PRNGKeyArray) -> PyTree[ArrayLike, S]
abstractmethod

Transforms the input substate to produce an altered output substate.

with_params(**kwargs) -> Self
classmethod

Constructor that accepts field names of InputT as keywords.

This is a convenience: it lets us instantiate an intervenor class without importing the parameter class associated with it.

Example

```python
from feedbax.intervene import CurlField
CurlField.with_params(amplitude=10.0)
```
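A minimal sketch of the with_params pattern using hypothetical frozen dataclasses: the intervenor class records its parameter class, and the classmethod builds the parameters from keyword arguments. CurlParams, CurlSketch, and the params_cls class attribute are assumptions for illustration, not feedbax's implementation.

```python
from dataclasses import dataclass, field
from typing import ClassVar


@dataclass(frozen=True)
class CurlParams:
    """Hypothetical parameter class for a curl-field intervenor."""
    amplitude: float = 0.0
    active: bool = False
    scale: float = 1.0


@dataclass(frozen=True)
class CurlSketch:
    """Hypothetical intervenor that records which parameter class it uses."""
    params_cls: ClassVar[type] = CurlParams
    params: CurlParams = field(default_factory=CurlParams)

    @classmethod
    def with_params(cls, **kwargs):
        # Build the parameters from keywords, so callers never import CurlParams.
        return cls(params=cls.params_cls(**kwargs))


curl = CurlSketch.with_params(amplitude=10.0)
```

In this sketch, an unknown keyword fails at construction time with a TypeError raised by the parameter class.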