Tasks

Feedbax tasks are objects that group together:

  1. A loss function that is used to evaluate performance on a task;
  2. Per-trial data to:
    1. Initialize the state of a model prior to evaluation on task trials;
    2. Provide the parameters of each task trial to the model and to the loss function.

Reaching

Simple reaching

feedbax.task.SimpleReachTaskInputs (Module)

Model input for a simple reaching task.

Attributes:

  • effector_target (CartesianState): The trajectory of effector target states to be presented to the model.

feedbax.task.SimpleReachTrialSpec (AbstractReachTrialSpec)

Trial specification for a simple reaching task.

Attributes:

  • inits (InitsWhereDict): A mapping from lambdas that select model substates to be initialized, to substates to initialize them with at the start of trials.
  • inputs (SimpleReachTaskInputs): For providing the model with the reach target.
  • target (CartesianState): The target trajectory for the mechanical end effector.
  • intervene (Mapping[IntervenorLabelStr, AbstractIntervenorInput]): A mapping from unique intervenor names to per-trial intervention parameters.

feedbax.task.SimpleReaches (AbstractTask)

Reaches between random endpoints in a rectangular workspace. No hold signal.

Validation set is center-out reaches.

Note

This passes a trajectory of target velocities all equal to zero, assuming that the user will choose a loss function that penalizes only the initial or final velocities. If the loss function penalizes the intervening velocities, this task no longer makes sense as a reaching task.
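To illustrate the note, here is a minimal plain-Python sketch in which tuples stand in for feedbax's CartesianState and plain functions stand in for loss classes (every name is hypothetical). The target velocity is zero at every step, so a loss that looks only at the final step rewards stopping at rest, while a loss over every step counts the reach itself as error.

```python
from typing import List, Tuple

Vec = Tuple[float, float]


def zero_velocity_target(goal: Vec, n_steps: int) -> Tuple[List[Vec], List[Vec]]:
    """Target positions fixed at the goal; target velocities all zero."""
    return [goal] * n_steps, [(0.0, 0.0)] * n_steps


def sq_err(a: Vec, b: Vec) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def final_velocity_loss(vel: List[Vec], target_vel: List[Vec]) -> float:
    # Penalize only the last time step: the effector should end at rest.
    return sq_err(vel[-1], target_vel[-1])


def all_steps_velocity_loss(vel: List[Vec], target_vel: List[Vec]) -> float:
    # Penalize every time step: any movement at all counts as error.
    return sum(sq_err(v, t) for v, t in zip(vel, target_vel))


# A reach that accelerates, moves, and comes to rest at the end.
velocities = [(0.0, 0.0), (0.5, 0.0), (1.0, 0.0), (0.5, 0.0), (0.0, 0.0)]
_, target_vel = zero_velocity_target(goal=(2.0, 0.0), n_steps=len(velocities))
print(final_velocity_loss(velocities, target_vel))      # 0.0
print(all_steps_velocity_loss(velocities, target_vel))  # 1.5
```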

Attributes:

  • n_steps (int): The number of time steps in each task trial.
  • loss_func (AbstractLoss): The loss function that grades performance on each trial.
  • workspace (Float[Array, 'bounds=2 ndim=2']): The rectangular workspace in which the reaches are distributed.
  • seed_validation (int): The random seed for generating the validation trials.
  • intervention_specs (TaskInterventionSpecs): A mapping from unique intervenor names to specifications for generating per-trial intervention parameters on training trials.
  • intervention_specs_validation (TaskInterventionSpecs): A mapping from unique intervenor names to specifications for generating per-trial intervention parameters on validation trials.
  • eval_grid_n (int): The number of evenly-spaced internal grid points of the workspace, along each dimension, at which a set of center-out reaches is placed.
  • eval_n_directions (int): The number of evenly-spread center-out reaches starting from each workspace grid point in the validation set. The number of trials in the validation set is equal to eval_n_directions * eval_grid_n ** 2.
  • eval_reach_length (float): The length (in space) of each reach in the validation set.

n_validation_trials: int property

Number of trials in the validation set.

get_train_trial_with_intervenor_params (key: PRNGKeyArray) -> AbstractTaskTrialSpec

Inherited from feedbax.task.AbstractTask

eval_trials (model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval_with_loss (model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval (model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT

Inherited from feedbax.task.AbstractTask

eval_ensemble_with_loss (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval_ensemble (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT

Inherited from feedbax.task.AbstractTask

eval_train_batch (model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Inherited from feedbax.task.AbstractTask

eval_ensemble_train_batch (models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Inherited from feedbax.task.AbstractTask

add_intervenors_to_base_model (model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]

Inherited from feedbax.task.AbstractTask

activate_interventions (labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self

Inherited from feedbax.task.AbstractTask

get_train_trial (key: PRNGKeyArray) -> SimpleReachTrialSpec

Random reach endpoints across the rectangular workspace.

Parameters:

  • key (PRNGKeyArray, required): A random key for generating the trial.
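A plain-Python sketch of drawing one training trial's endpoints uniformly within the workspace. It assumes the workspace is laid out as ((x_min, y_min), (x_max, y_max)), matching the 'bounds=2 ndim=2' shape, and uses a seeded random.Random in place of a JAX PRNGKeyArray; it illustrates the idea and is not the library's implementation.

```python
import random
from typing import Tuple

Vec = Tuple[float, float]
Workspace = Tuple[Vec, Vec]  # ((x_min, y_min), (x_max, y_max)), an assumed layout


def uniform_point(rng: random.Random, workspace: Workspace) -> Vec:
    """Draw one point uniformly from the rectangular workspace."""
    (x_min, y_min), (x_max, y_max) = workspace
    return (rng.uniform(x_min, x_max), rng.uniform(y_min, y_max))


def random_reach_endpoints(seed: int, workspace: Workspace) -> Tuple[Vec, Vec]:
    """Independent uniform start and goal positions for a single reach."""
    rng = random.Random(seed)  # stands in for a JAX PRNG key
    start = uniform_point(rng, workspace)
    goal = uniform_point(rng, workspace)
    return start, goal


workspace = ((-1.0, -1.0), (1.0, 1.0))
start, goal = random_reach_endpoints(seed=0, workspace=workspace)
```

Reusing a seed reproduces the same trial, much as passing the same JAX key to get_train_trial would.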

get_validation_trials (key: PRNGKeyArray) -> SimpleReachTrialSpec

Center-out reach sets in a grid across the rectangular workspace.
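A sketch of how such a validation set could be laid out, in plain Python. It assumes grid points are spaced like internal_grid_points (n points per dimension, strictly inside the bounds) and that reach directions are spread evenly around the full circle starting at angle 0; these are illustrative choices, not necessarily the library's.

```python
import math
from typing import List, Tuple

Vec = Tuple[float, float]
Workspace = Tuple[Vec, Vec]  # ((x_min, y_min), (x_max, y_max))


def internal_grid(workspace: Workspace, n: int) -> List[Vec]:
    """n * n evenly-spaced points strictly inside the workspace."""
    (x_min, y_min), (x_max, y_max) = workspace
    xs = [x_min + (x_max - x_min) * (i + 1) / (n + 1) for i in range(n)]
    ys = [y_min + (y_max - y_min) * (j + 1) / (n + 1) for j in range(n)]
    return [(x, y) for y in ys for x in xs]


def center_out_reaches(
    workspace: Workspace, grid_n: int, n_directions: int, reach_length: float
) -> List[Tuple[Vec, Vec]]:
    """(start, goal) pairs: n_directions evenly-spread reaches out of each grid point."""
    reaches = []
    for cx, cy in internal_grid(workspace, grid_n):
        for k in range(n_directions):
            angle = 2 * math.pi * k / n_directions  # evenly spread around the circle
            goal = (cx + reach_length * math.cos(angle), cy + reach_length * math.sin(angle))
            reaches.append(((cx, cy), goal))
    return reaches


reaches = center_out_reaches(((-1.0, -1.0), (1.0, 1.0)), grid_n=2, n_directions=8, reach_length=0.5)
print(len(reaches))  # 32, i.e. eval_n_directions * eval_grid_n ** 2 = 8 * 2 ** 2
```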

Delayed (cued) reaching

feedbax.task.DelayedReachTaskInputs (Module)

Model input for a delayed reaching task.

Attributes:

  • effector_target (CartesianState): The trajectory of effector target states to be presented to the model.
  • hold (Int[Array, 'time 1']): The hold/go signal (1 = hold, 0 = go) to be presented to the model.
  • target_on (Int[Array, 'time 1']): A signal indicating to the model when the value of effector_target should be interpreted as a reach target. Otherwise, if zeros are passed for the target during (say) the hold period, the model may interpret them as meaningful, that is, as "your reach target is at 0".

feedbax.task.DelayedReachTrialSpec (AbstractReachTrialSpec)

Trial specification for a delayed reaching task.

Attributes:

  • inits (WhereDict): A mapping from lambdas that select model substates to be initialized, to substates to initialize them with at the start of trials.
  • inputs (DelayedReachTaskInputs): For providing the model with the reach target and hold signal.
  • target (CartesianState): The target trajectory for the mechanical end effector.
  • epoch_start_idxs (Int[Array, n_epochs]): The indices of the start of each epoch in the trial.
  • intervene (Mapping[IntervenorLabelStr, AbstractIntervenorInput]): A mapping from unique intervenor names to per-trial intervention parameters.
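A sketch of how per-step hold and target_on signals could be derived from epoch_start_idxs. It assumes epochs are numbered from 0, that epoch i covers the steps from its start index up to the next epoch's start, and that the "on" epochs are given as lists like the DelayedReaches fields hold_epochs and target_on_epochs. Plain lists replace the Int[Array, 'time 1'] arrays, and the helper names and epoch layout are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence


def epoch_of_each_step(epoch_start_idxs: Sequence[int], n_steps: int) -> List[int]:
    """Epoch index of every time step; epoch i covers steps [start_i, start_{i+1})."""
    return [bisect_right(epoch_start_idxs, t) - 1 for t in range(n_steps)]


def epoch_signal(epoch_start_idxs: Sequence[int], n_steps: int, on_epochs: Sequence[int]) -> List[int]:
    """1 at every step that falls in one of on_epochs, 0 elsewhere."""
    on = set(on_epochs)
    return [int(e in on) for e in epoch_of_each_step(epoch_start_idxs, n_steps)]


# Illustrative trial: epoch 0 = fixation (steps 0-2), epoch 1 = target shown
# while holding (steps 3-6), epoch 2 = movement (steps 7-9).
starts = [0, 3, 7]
hold = epoch_signal(starts, 10, on_epochs=[0, 1])       # hold until the go cue
target_on = epoch_signal(starts, 10, on_epochs=[1, 2])  # target visible from epoch 1 on
```

Because target_on is 0 during fixation, the model can ignore whatever placeholder value effector_target carries there, which is the ambiguity the target_on signal is meant to remove.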

feedbax.task.DelayedReaches (AbstractTask)

Reaches between uniform random endpoints in a rectangular workspace, with each trial divided into epochs.

For example, this allows for a stimulus epoch, followed by a delay period, then movement.

Attributes:

  • loss_func (AbstractLoss): The loss function that grades performance on each trial.
  • workspace (Float[Array, 'bounds=2 ndim=2']): The rectangular workspace in which the reaches are distributed.
  • n_steps (int): The number of time steps in each task trial.
  • epoch_len_ranges (Tuple[Tuple[int, int], ...]): The ranges from which to uniformly sample the durations of the task phases for each task trial.
  • target_on_epochs (Int[Array, _]): The epochs in which the "target on" signal is turned on.
  • hold_epochs (Int[Array, _]): The epochs in which the hold signal is turned on.
  • eval_n_directions (int): The number of evenly-spread center-out reaches starting from each workspace grid point in the validation set. The number of trials in the validation set is equal to eval_n_directions * eval_grid_n ** 2.
  • eval_reach_length (float): The length (in space) of each reach in the validation set.
  • eval_grid_n (int): The number of evenly-spaced internal grid points of the workspace, along each dimension, at which a set of center-out reaches is placed.
  • seed_validation (int): The random seed for generating the validation trials.
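A sketch of turning epoch_len_ranges into epoch start indices for one trial: sample a duration per epoch, then take a running sum. It assumes the (low, high) ranges are inclusive and that the steps left after the sampled epochs form one final epoch (for example, movement) running to the end of the trial; both are assumptions, and random.Random stands in for a JAX key.

```python
import random
from itertools import accumulate
from typing import List, Sequence, Tuple


def sample_epoch_start_idxs(
    rng: random.Random,
    epoch_len_ranges: Sequence[Tuple[int, int]],
    n_steps: int,
) -> List[int]:
    """Sample a duration for each listed epoch and convert durations to start indices.

    The steps left over after the sampled epochs form one final epoch
    that runs to the end of the trial.
    """
    # randint includes both endpoints (an assumption about the ranges' semantics).
    lens = [rng.randint(lo, hi) for lo, hi in epoch_len_ranges]
    # The first epoch starts at 0; each later epoch starts where the previous one ends.
    starts = [0] + list(accumulate(lens))
    if starts[-1] >= n_steps:
        raise ValueError("sampled epochs leave no time steps for the final epoch")
    return starts


# e.g. a stimulus epoch of 5-15 steps and a delay of 10-20 steps, then movement
starts = sample_epoch_start_idxs(random.Random(0), ((5, 15), (10, 20)), n_steps=100)
```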

n_validation_trials: int property

Number of trials in the validation set.

get_train_trial_with_intervenor_params (key: PRNGKeyArray) -> AbstractTaskTrialSpec

Inherited from feedbax.task.AbstractTask

eval_trials (model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval_with_loss (model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval (model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT

Inherited from feedbax.task.AbstractTask

eval_ensemble_with_loss (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]

Inherited from feedbax.task.AbstractTask

eval_ensemble (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT

Inherited from feedbax.task.AbstractTask

eval_train_batch (model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Inherited from feedbax.task.AbstractTask

eval_ensemble_train_batch (models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Inherited from feedbax.task.AbstractTask

add_intervenors_to_base_model (model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]

Inherited from feedbax.task.AbstractTask

activate_interventions (labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self

Inherited from feedbax.task.AbstractTask

validation_plots (states, trial_specs: Optional[AbstractTaskTrialSpec] = None) -> Mapping[str, go.Figure]

Inherited from feedbax.task.AbstractTask

get_train_trial (key: PRNGKeyArray) -> DelayedReachTrialSpec

Random reach endpoints across the rectangular workspace.

Parameters:

  • key (PRNGKeyArray, required): A random key for generating the trial.

get_validation_trials (key: PRNGKeyArray) -> DelayedReachTrialSpec

Center-out reach sets in a grid across the rectangular workspace.

Abstract base classes

feedbax.task.AbstractTaskTrialSpec (Module)

Abstract base class for trial specifications provided by a task.

Attributes:

  • inits (AbstractVar[InitsWhereDict]): A mapping from lambdas that select model substates to be initialized, to substates to initialize them with.
  • inputs (AbstractVar[PyTree]): A PyTree of inputs to the model.
  • target (AbstractVar[PyTree[Array]]): A PyTree of target states.
  • intervene (AbstractVar[Mapping[IntervenorLabelStr, AbstractIntervenorInput]]): A mapping from unique intervenor names to per-trial intervention parameters.
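A plain-Python sketch of the shape of a trial specification. A dataclass stands in for the Equinox Module, string keys such as "mechanics.effector" stand in for where-lambdas, and lists of tuples stand in for PyTrees of arrays. Only the four field names come from the attributes above; the class, the helper, and the example values are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

Vec = Tuple[float, float]


@dataclass
class SketchReachTrialSpec:
    # Substate path -> value to initialize it with at the start of the trial.
    inits: Dict[str, Any]
    # What the model receives at each time step (here: the target position).
    inputs: List[Vec]
    # What the loss function compares the model's trajectory against.
    target: List[Vec]
    # Intervenor label -> per-trial intervention parameters.
    intervene: Dict[str, Dict[str, Any]] = field(default_factory=dict)


def make_trial(start: Vec, goal: Vec, n_steps: int) -> SketchReachTrialSpec:
    return SketchReachTrialSpec(
        inits={"mechanics.effector": start},  # start the effector at `start`
        inputs=[goal] * n_steps,
        target=[goal] * n_steps,
    )


trial = make_trial(start=(0.0, 0.0), goal=(0.5, 0.2), n_steps=4)
```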

feedbax.task.AbstractTask (Module)

Abstract base class for tasks.

Provides methods for evaluating suitable models or ensembles of models on training and validation trials.

Subclasses must provide:

  • a method that generates training trials
  • a method that generates a set of validation trials, given a random key
  • a field for a loss function that grades performance on the task
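The obligations in that list can be sketched with Python's abc module, which here stands in for Equinox's AbstractVar and abstractmethod machinery. Integer seeds stand in for JAX keys, and SketchTask and RandomGoalTask are hypothetical classes, not part of feedbax.

```python
import random
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Tuple


class SketchTask(ABC):
    """Mirrors the three obligations listed above, nothing more."""

    # Subclasses supply a loss function as an attribute.
    loss_func: Callable[[Any, Any], float]

    @abstractmethod
    def get_train_trial(self, key: int) -> Any:
        """Return one training trial."""

    @abstractmethod
    def get_validation_trials(self, key: int) -> List[Any]:
        """Return the full set of validation trials."""


class RandomGoalTask(SketchTask):
    def __init__(self, n_validation: int = 4):
        self.n_validation = n_validation
        # Squared distance between a final position and the goal.
        self.loss_func = lambda pos, goal: sum((p - g) ** 2 for p, g in zip(pos, goal))

    def get_train_trial(self, key: int) -> Tuple[float, float]:
        rng = random.Random(key)
        return (rng.uniform(-1, 1), rng.uniform(-1, 1))

    def get_validation_trials(self, key: int) -> List[Tuple[float, float]]:
        rng = random.Random(key)
        return [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(self.n_validation)]
```

Instantiating SketchTask directly raises TypeError, so a subclass that forgets one of the methods fails early rather than during training.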

Attributes:

  • loss_func (AbstractVar[AbstractLoss]): The loss function that grades task performance.
  • n_steps (AbstractVar[int]): The number of time steps in the task trials.
  • seed_validation (AbstractVar[int]): The random seed for generating the validation trials.
  • intervention_specs (AbstractVar[TaskInterventionSpecs]): Mappings from unique intervenor names to specifications for generating per-trial intervention parameters. Distinct fields provide mappings for training and validation trials, though the two may be identical depending on scheduling.

validation_trials: AbstractTaskTrialSpec cached property

The set of validation trials associated with the task.

get_train_trial (key: PRNGKeyArray) -> AbstractTaskTrialSpec
abstractmethod

Return a single training trial specification.

Parameters:

  • key (PRNGKeyArray, required): A random key for generating the trial.

get_train_trial_with_intervenor_params (key: PRNGKeyArray) -> AbstractTaskTrialSpec

Return a single training trial specification, including intervention parameters.

Parameters:

  • key (PRNGKeyArray, required): A random key for generating the trial.

get_validation_trials (key: PRNGKeyArray) -> AbstractTaskTrialSpec
abstractmethod

Return a set of validation trials, given a random key.

Subclasses must override this method. However, the validation trials used during training, provided by self.validation_trials, are determined by the field self.seed_validation, which subclasses must also implement.

Parameters:

  • key (PRNGKeyArray, required): A random key for generating the validation set.
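A sketch of the relationship between get_validation_trials, seed_validation and the cached validation_trials property, using functools.cached_property and an integer seed in place of a JAX key. SeededValidationTask is hypothetical; because Equinox modules are immutable, the library's actual caching mechanism may differ.

```python
import random
from functools import cached_property
from typing import List, Tuple


class SeededValidationTask:
    def __init__(self, seed_validation: int, n_trials: int = 3):
        self.seed_validation = seed_validation
        self.n_trials = n_trials

    def get_validation_trials(self, key: int) -> List[Tuple[float, float]]:
        # Any key gives a valid validation set...
        rng = random.Random(key)
        return [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(self.n_trials)]

    @cached_property
    def validation_trials(self) -> List[Tuple[float, float]]:
        # ...but the one used during training is always built from seed_validation,
        # and is computed only once per task object.
        return self.get_validation_trials(self.seed_validation)
```

Two tasks with the same seed_validation therefore validate on identical trials, which keeps validation losses comparable across training runs.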

n_validation_trials () -> int

Number of trials in the validation set.

eval_trials (model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]

Evaluate a model on a set of trials.

Parameters:

  • model (AbstractModel[StateT], required): The model to evaluate.
  • trial_specs (AbstractTaskTrialSpec, required): The set of trials to evaluate the model on.
  • keys (PRNGKeyArray, required): For providing randomness during model evaluation.

eval_with_loss (model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]

Evaluate a model on the task's validation set of trials.

Parameters:

  • model (AbstractModel[StateT], required): The model to evaluate.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation.

Returns:

  • StateT: The evaluated model states.
  • LossDict: The losses for the trials in the validation set.

eval (model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT

Return states for a model evaluated on the task's set of validation trials.

Parameters:

  • model (AbstractModel[StateT], required): The model to evaluate.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation.

eval_ensemble_with_loss (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]

Return states and losses for an ensemble of models evaluated on the task's set of validation trials.

Parameters:

  • models (AbstractModel[StateT], required): The ensemble of models to evaluate.
  • n_replicates (int, required): The number of models in the ensemble.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation. Will be split into n_replicates keys.
  • ensemble_random_trials (bool, default True): If False, each model in the ensemble will be evaluated on the same set of trials.

eval_ensemble (models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT

Return states for an ensemble of models evaluated on the task's set of validation trials.

Parameters:

  • models (AbstractModel[StateT], required): The ensemble of models to evaluate.
  • n_replicates (int, required): The number of models in the ensemble.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation. Will be split into n_replicates keys.
  • ensemble_random_trials (bool, default True): If False, each model in the ensemble will be evaluated on the same set of trials.

eval_train_batch (model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Evaluate a model on a single batch of training trials.

Parameters:

  • model (AbstractModel[StateT], required): The model to evaluate.
  • batch_size (int, required): The number of trials in the batch.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation.

Returns:

  • StateT: The evaluated model states.
  • LossDict: The losses for the trials in the batch.
  • AbstractTaskTrialSpec: The trial specifications for the batch.

eval_ensemble_train_batch (models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]

Evaluate an ensemble of models on a single training batch.

Parameters:

  • models (AbstractModel[StateT], required): The ensemble of models to evaluate.
  • n_replicates (int, required): The number of models in the ensemble.
  • batch_size (int, required): The number of trials in the batch to evaluate.
  • key (PRNGKeyArray, required): For providing randomness during model evaluation.
  • ensemble_random_trials (bool, default True): If False, each model in the ensemble will be evaluated on the same set of trials.

Returns:

  • StateT: The evaluated model states, for each trial and each model in the ensemble.
  • LossDict: The losses for the trials in the batch, for each model in the ensemble.
  • AbstractTaskTrialSpec: The trial specifications for the batch.
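A plain-Python sketch of how one key can yield a batch per ensemble member, and what ensemble_random_trials=False changes. It assumes integer seeds in place of JAX keys and derives per-replicate sub-seeds with random.getrandbits as a stand-in for key splitting; ensemble_train_batches is hypothetical and generates only random goal positions.

```python
import random
from typing import List, Tuple

Vec = Tuple[float, float]


def ensemble_train_batches(
    n_replicates: int, batch_size: int, seed: int, ensemble_random_trials: bool = True
) -> List[List[Vec]]:
    """One batch of random goal positions per ensemble member."""
    rng = random.Random(seed)
    # Derive one sub-seed per replicate (a stand-in for splitting a JAX key).
    sub_seeds = [rng.getrandbits(64) for _ in range(n_replicates)]
    if not ensemble_random_trials:
        # Every replicate reuses the first sub-seed, so all receive identical trials.
        sub_seeds = [sub_seeds[0]] * n_replicates
    batches = []
    for s in sub_seeds:
        r = random.Random(s)
        batches.append([(r.uniform(-1, 1), r.uniform(-1, 1)) for _ in range(batch_size)])
    return batches
```

Sharing trials isolates differences due to model initialization, while independent trials give each replicate its own sample of the task.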

add_intervenors_to_base_model (model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]

Add the task's scheduled intervenors to a model.

Assumes that the model has the appropriate structure to admit the intervention. This depends on the where and stage_name properties that the original call to schedule_intervenor stored in the task's intervention_specs field when it added the intervenors to the task.

  • where should pick out an AbstractStagedModel component of model.
  • If defined, stage_name should be the name of one of the stages of the AbstractStagedModel component picked out by where.

Any existing intervenors in the model that were scheduled with another task will be removed to prevent conflicts. Other intervenors, which were added directly to the model without being scheduled with a task, will not be removed.

Note

This method is mostly useful when evaluating a trained model on a task with a different set of interventions than the one it was trained on.
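The removal rule above can be sketched with a hypothetical record that marks whether an intervenor came from a task's schedule. SketchIntervenor, its scheduled flag, and replace_scheduled_intervenors are illustrative inventions; how feedbax actually records this, and how it inserts intervenors into model stages, is not shown here.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class SketchIntervenor:
    label: str
    scheduled: bool  # True if it was added by scheduling it with a task


def replace_scheduled_intervenors(
    existing: List[SketchIntervenor], task_intervenors: List[SketchIntervenor]
) -> List[SketchIntervenor]:
    # Keep intervenors added directly to the model; drop any left over from another task.
    kept = [iv for iv in existing if not iv.scheduled]
    # Then add the intervenors scheduled with this task.
    return kept + list(task_intervenors)
```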

activate_interventions (labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self

Return a task where scheduling is active only for the interventions with the given labels.
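A sketch of how the labels argument could be resolved into a set of active intervenor labels. It reads NonCharSequence as "a sequence that is not a bare string", which is an assumption based on the type name, and it leaves out labels_validation and validation_same_schedule, whose exact behaviour is not described here. resolve_active_labels is hypothetical.

```python
from typing import Iterable, Sequence, Set, Union


def resolve_active_labels(labels: Union[str, Sequence[str]], all_labels: Iterable[str]) -> Set[str]:
    all_set = set(all_labels)
    if labels == "all":
        return all_set
    if labels == "none":
        return set()
    if isinstance(labels, str):
        # A bare string would iterate as characters, so reject it.
        raise TypeError("pass a sequence of labels, 'all', or 'none'")
    unknown = set(labels) - all_set
    if unknown:
        raise ValueError(f"unknown intervenor labels: {sorted(unknown)}")
    return set(labels)
```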

validation_plots (states, trial_specs: Optional[AbstractTaskTrialSpec] = None) -> Mapping[str, go.Figure]
abstractmethod

Returns a basic set of plots to visualize performance on the task.

Useful functions for building tasks

feedbax.task.internal_grid_points (bounds: Float[Array, 'bounds=2 ndim=2'], n: int = 2) -> Float[Array, 'n**ndim ndim=2']

Return an array of evenly-spaced grid points internal to the bounds.

Parameters:

  • bounds (Float[Array, 'bounds=2 ndim=2'], required): The outer bounds of the grid.
  • n (int, default 2): The number of internal grid points along each dimension.

Example

internal_grid_points(
    bounds=((0, 0), (9, 9)),
    n=2,
)
>> Array([[3., 3.], [6., 3.], [3., 6.], [6., 6.]])
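A pure-Python sketch of the same computation for any number of dimensions, returning a list of tuples rather than a JAX array. It is not the library's source; it only reproduces the ordering in the example above, where the first coordinate varies fastest.

```python
from itertools import product
from typing import List, Sequence, Tuple


def internal_grid_points_sketch(bounds: Sequence[Sequence[float]], n: int = 2) -> List[Tuple[float, ...]]:
    lower, upper = bounds
    # n ticks per dimension, strictly inside (lo, hi): the interior of n + 2 evenly-spaced ticks.
    ticks = [[lo + (hi - lo) * (i + 1) / (n + 1) for i in range(n)] for lo, hi in zip(lower, upper)]
    # product() varies its last argument fastest; reverse so the first coordinate varies fastest.
    return [tuple(reversed(p)) for p in product(*reversed(ticks))]


points = internal_grid_points_sketch(((0, 0), (9, 9)), n=2)
print(points)  # [(3.0, 3.0), (6.0, 3.0), (3.0, 6.0), (6.0, 6.0)]
```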

Using lambda functions as dictionary keys

feedbax.task.WhereDict

An OrderedDict that allows limited use of where lambdas as keys.

In particular, keys can be lambdas that take a single argument, and return a single (nested) attribute accessed from that argument.

Lambdas are parsed to equivalent strings, which can be used interchangeably as keys. For example, the following are equivalent when init_spec is a WhereDict:

init_spec[lambda state: state.mechanics.effector]
init_spec['mechanics.effector']

Finally, a tuple[Callable, str] may also be provided as a key, for cases where different unique entries must be included for the same Callable. For example, the following are equivalent:

init_spec[(lambda state: state.mechanics.effector, "first")]
init_spec['mechanics.effector#first']

Note that the hash symbol # is used to concatenate the usual string representation for the callable, with the paired string.
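One way to turn a where-lambda into its string key is to call it on a proxy object that records each attribute access. This is an illustrative technique, not necessarily how feedbax parses lambdas; SketchWhereDict and its helpers are hypothetical and override only item access and `in`.

```python
from collections import OrderedDict
from typing import Any, Callable, Tuple, Union


class _AttrPath:
    """Proxy that records the chain of attribute names accessed on it."""

    def __init__(self, path: Tuple[str, ...] = ()):
        self._path = path

    def __getattr__(self, name: str) -> "_AttrPath":
        # Only called for attributes not found normally, i.e. the ones the lambda accesses.
        return _AttrPath(self._path + (name,))


def where_to_str(where: Callable[[Any], Any]) -> str:
    result = where(_AttrPath())
    if not isinstance(result, _AttrPath) or not result._path:
        raise ValueError("where must return a (nested) attribute of its argument")
    return ".".join(result._path)


Key = Union[str, Callable[[Any], Any], Tuple[Callable[[Any], Any], str]]


def key_to_str(key: Key) -> str:
    if isinstance(key, str):
        return key
    if isinstance(key, tuple):
        where, label = key
        return f"{where_to_str(where)}#{label}"  # '#' joins the path and the label
    return where_to_str(key)


class SketchWhereDict(OrderedDict):
    """OrderedDict whose keys may be strings, where-lambdas, or (lambda, str) pairs."""

    def __setitem__(self, key: Key, value: Any) -> None:
        super().__setitem__(key_to_str(key), value)

    def __getitem__(self, key: Key) -> Any:
        return super().__getitem__(key_to_str(key))

    def __contains__(self, key: object) -> bool:
        return super().__contains__(key_to_str(key))


init_spec = SketchWhereDict()
init_spec[lambda state: state.mechanics.effector] = "initial effector state"
print(init_spec["mechanics.effector"])  # initial effector state
```

Because every access normalizes the key, indexing by lambda pays for a call on the proxy, which is one plausible source of the access overhead described under Performance.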

Performance

For typical initialization mappings (1-10 items), construction is on the order of 50x slower than OrderedDict. Access is about 2-20x slower, depending on whether the dict is indexed by string or by callable.

However, we only need to do a single construction and a single access of init_spec per batch/evaluation, so performance shouldn't matter much in practice: the added overhead is <50 µs/batch, and a batch normally takes at least 20,000 µs to train.