Tasks
Feedbax tasks are objects that group together:
- A loss function that is used to evaluate performance on a task;
- Per-trial data to:
  - Initialize the state of a model prior to evaluation on task trials;
  - Provide the parameters of each trial to the model and to the loss function.
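To make this grouping concrete, here is a plain-Python sketch of the idea. ToyTrialSpec, ToyTask, and the toy model and loss are hypothetical stand-ins invented for illustration. They are not feedbax classes, and real feedbax tasks work with JAX arrays, PRNG keys, and PyTrees of model states.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyTrialSpec:
    # Hypothetical stand-in for a trial specification (not a feedbax class).
    inits: Dict[str, float]  # initial values for parts of the model state
    inputs: List[float]      # what the model is shown at each time step
    target: List[float]      # what the loss function compares against


@dataclass
class ToyTask:
    # Hypothetical stand-in for a task: a loss function plus a trial generator.
    loss_func: Callable[[List[float], List[float]], float]
    make_trial: Callable[[int], ToyTrialSpec]

    def eval_trial(self, run_model: Callable[[ToyTrialSpec], List[float]], seed: int) -> float:
        spec = self.make_trial(seed)
        states = run_model(spec)  # model starts from spec.inits and is fed spec.inputs
        return self.loss_func(states, spec.target)


def make_trial(seed: int) -> ToyTrialSpec:
    goal = float(seed)  # toy "randomness": the goal is just the seed
    return ToyTrialSpec(inits={"position": 0.0}, inputs=[goal] * 3, target=[goal] * 3)


def run_model(spec: ToyTrialSpec) -> List[float]:
    # Toy model: on every step, moves halfway toward its current input.
    position = spec.inits["position"]
    states = []
    for u in spec.inputs:
        position += 0.5 * (u - position)
        states.append(position)
    return states


def mse(states: List[float], target: List[float]) -> float:
    return sum((s - t) ** 2 for s, t in zip(states, target)) / len(target)


task = ToyTask(loss_func=mse, make_trial=make_trial)
print(task.eval_trial(run_model, seed=2))  # 0.4375
```

The task owns the loss and the trial generator, and the model only ever sees what the trial specification gives it.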
Reaching
Simple reaching
feedbax.task.SimpleReachTaskInputs (Module)
Model input for a simple reaching task.
Attributes:
Name | Type | Description |
---|---|---|
effector_target | CartesianState | The trajectory of effector target states to be presented to the model. |
feedbax.task.SimpleReachTrialSpec (AbstractReachTrialSpec)
Trial specification for a simple reaching task.
Attributes:
Name | Type | Description |
---|---|---|
inits | InitsWhereDict | A mapping from where lambdas, which select model substates, to the values used to initialize those substates. |
inputs | SimpleReachTaskInputs | For providing the model with the reach target. |
target | CartesianState | The target trajectory for the mechanical end effector. |
intervene | Mapping[IntervenorLabelStr, AbstractIntervenorInput] | A mapping from unique intervenor names, to per-trial intervention parameters. |
feedbax.task.SimpleReaches (AbstractTask)
Reaches between random endpoints in a rectangular workspace, with no hold signal.
The validation set consists of center-out reaches.
Note
This passes a trajectory of target velocities all equal to zero, assuming that the user will choose a loss function that penalizes only the initial or final velocities. If the loss function penalizes the intervening velocities, this task no longer makes sense as a reaching task.
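Here is a small one-dimensional illustration of this note, in plain Python with made-up velocity values. A velocity target that is zero everywhere is harmless when only the final step is penalized. Penalizing every step, by contrast, punishes the reach for moving at all. Both loss functions are hypothetical and are not feedbax loss terms.

```python
from typing import List

n_steps = 5
target_vel = [0.0] * n_steps           # the task's target: zero velocity on every step
reach_vel = [0.0, 0.5, 1.0, 0.5, 0.0]  # made-up bell-shaped velocity profile of a good reach


def final_velocity_loss(vel: List[float], target: List[float]) -> float:
    # Only the last step counts: the effector should be at rest when the trial ends.
    return (vel[-1] - target[-1]) ** 2


def all_steps_velocity_loss(vel: List[float], target: List[float]) -> float:
    # Every step counts: any movement at all is penalized.
    return sum((v - t) ** 2 for v, t in zip(vel, target))


print(final_velocity_loss(reach_vel, target_vel))      # 0.0
print(all_steps_velocity_loss(reach_vel, target_vel))  # 1.5
```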
Attributes:
Name | Type | Description |
---|---|---|
n_steps | int | The number of time steps in each task trial. |
loss_func | AbstractLoss | The loss function that grades performance on each trial. |
workspace | Float[Array, 'bounds=2 ndim=2'] | The rectangular workspace in which the reaches are distributed. |
seed_validation | int | The random seed for generating the validation trials. |
intervention_specs | TaskInterventionSpecs | A mapping from unique intervenor names, to specifications for generating per-trial intervention parameters on training trials. |
intervention_specs_validation | TaskInterventionSpecs | A mapping from unique intervenor names, to specifications for generating per-trial intervention parameters on validation trials. |
eval_grid_n | int | The number of evenly-spaced internal grid points of the workspace, along each dimension, at which a set of center-out reaches is placed. |
eval_n_directions | int | The number of evenly-spread center-out reaches starting from each workspace grid point in the validation set. The number of trials in the validation set is equal to eval_n_directions * eval_grid_n ** 2. |
eval_reach_length | float | The length (in space) of each reach in the validation set. |
n_validation_trials: int (property)
Number of trials in the validation set.
get_train_trial_with_intervenor_params(key: PRNGKeyArray) -> AbstractTaskTrialSpec
Inherited from feedbax.task.AbstractTask
eval_trials(model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval_with_loss(model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval(model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT
Inherited from feedbax.task.AbstractTask
eval_ensemble_with_loss(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval_ensemble(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT
Inherited from feedbax.task.AbstractTask
eval_train_batch(model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Inherited from feedbax.task.AbstractTask
eval_ensemble_train_batch(models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Inherited from feedbax.task.AbstractTask
add_intervenors_to_base_model(model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]
Inherited from feedbax.task.AbstractTask
activate_interventions(labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self
Inherited from feedbax.task.AbstractTask
get_train_trial(key: PRNGKeyArray) -> SimpleReachTrialSpec
Random reach endpoints across the rectangular workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A random key for generating the trial. | required |
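As a rough sketch of what "random reach endpoints across the rectangular workspace" means, the function below draws a start and a goal uniformly within the workspace bounds. It uses Python's random.Random, seeded with an integer, as a stand-in for a JAX PRNGKeyArray. The helper names are hypothetical, and feedbax's own implementation may sample differently.

```python
import random
from typing import Tuple

Point = Tuple[float, float]
Bounds = Tuple[Point, Point]  # ((x_min, y_min), (x_max, y_max)), like the workspace field


def sample_point(rng: random.Random, workspace: Bounds) -> Point:
    (x_min, y_min), (x_max, y_max) = workspace
    return (rng.uniform(x_min, x_max), rng.uniform(y_min, y_max))


def sample_reach_endpoints(seed: int, workspace: Bounds) -> Tuple[Point, Point]:
    # Hypothetical helper: independent uniform draws for the start and the goal.
    rng = random.Random(seed)
    start = sample_point(rng, workspace)
    goal = sample_point(rng, workspace)
    return start, goal


workspace = ((-1.0, -1.0), (1.0, 1.0))
start, goal = sample_reach_endpoints(0, workspace)
print(start, goal)
```

Passing the same seed twice gives the same endpoints, much as reusing a PRNG key does in JAX.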
get_validation_trials(key: PRNGKeyArray) -> SimpleReachTrialSpec
Center-out reach sets in a grid across the rectangular workspace.
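The plain-Python sketch below shows one way to lay out a set of center-out reaches around a single grid point. It also computes the validation-set size implied by the eval_grid_n and eval_n_directions attributes when the grid is two-dimensional. The helper names are hypothetical, and placing the first reach at angle zero is an assumption.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def center_out_reaches(center: Point, n_directions: int, reach_length: float) -> List[Tuple[Point, Point]]:
    # (start, goal) pairs: every reach starts at `center` and ends `reach_length`
    # away, with directions spaced evenly around the circle.
    cx, cy = center
    reaches = []
    for i in range(n_directions):
        angle = 2 * math.pi * i / n_directions
        goal = (cx + reach_length * math.cos(angle), cy + reach_length * math.sin(angle))
        reaches.append((center, goal))
    return reaches


def count_validation_trials(eval_grid_n: int, eval_n_directions: int) -> int:
    # One set of center-out reaches per point of an eval_grid_n x eval_grid_n grid.
    return eval_grid_n ** 2 * eval_n_directions


reaches = center_out_reaches((0.0, 0.0), n_directions=8, reach_length=0.5)
print(len(reaches))                   # 8
print(count_validation_trials(2, 7))  # 28
```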
Delayed (cued) reaching
feedbax.task.DelayedReachTaskInputs (Module)
Model input for a delayed reaching task.
Attributes:
Name | Type | Description |
---|---|---|
effector_target | CartesianState | The trajectory of effector target states to be presented to the model. |
hold | Int[Array, 'time 1'] | The hold/go (1/0) signal to be presented to the model. |
target_on | Int[Array, 'time 1'] | A signal indicating to the model when the value of … |
feedbax.task.DelayedReachTrialSpec (AbstractReachTrialSpec)
Trial specification for a delayed reaching task.
Attributes:
Name | Type | Description |
---|---|---|
inits | WhereDict | A mapping from where lambdas, which select model substates, to the values used to initialize those substates. |
inputs | DelayedReachTaskInputs | For providing the model with the reach target and hold signal. |
target | CartesianState | The target trajectory for the mechanical end effector. |
epoch_start_idxs | Int[Array, n_epochs] | The indices of the start of each epoch in the trial. |
intervene | Mapping[IntervenorLabelStr, AbstractIntervenorInput] | A mapping from unique intervenor names, to per-trial intervention parameters. |
feedbax.task.DelayedReaches (AbstractTask)
Reaches between uniformly random endpoints in a rectangular workspace, with each trial divided into epochs: for example, a stimulus epoch, followed by a delay period, then movement.
Attributes:
Name | Type | Description |
---|---|---|
loss_func | AbstractLoss | The loss function that grades performance on each trial. |
workspace | Float[Array, 'bounds=2 ndim=2'] | The rectangular workspace in which the reaches are distributed. |
n_steps | int | The number of time steps in each task trial. |
epoch_len_ranges | Tuple[Tuple[int, int], ...] | The ranges from which to uniformly sample the durations of the task phases for each task trial. |
target_on_epochs | Int[Array, _] | The epochs in which the "target on" signal is turned on. |
hold_epochs | Int[Array, _] | The epochs in which the hold signal is turned on. |
eval_n_directions | int | The number of evenly-spread center-out reaches starting from each workspace grid point in the validation set. The number of trials in the validation set is equal to eval_n_directions * eval_grid_n ** 2. |
eval_reach_length | float | The length (in space) of each reach in the validation set. |
eval_grid_n | int | The number of evenly-spaced internal grid points of the workspace, along each dimension, at which a set of center-out reaches is placed. |
seed_validation | int | The random seed for generating the validation trials. |
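The plain-Python sketch below illustrates how epoch_len_ranges, hold_epochs, and target_on_epochs could combine into per-time-step signals. It rests on three assumptions: each (low, high) range is half-open, a final (movement) epoch follows the sampled ones and lasts until the end of the trial, and epochs are indexed from zero. The helper names are hypothetical, and this is not feedbax's implementation.

```python
import random
from typing import List, Sequence, Tuple


def sample_epoch_starts(rng: random.Random, epoch_len_ranges: Sequence[Tuple[int, int]]) -> List[int]:
    # Assumption: each range is half-open, and one extra epoch starts after the
    # sampled ones and runs to the end of the trial.
    starts = [0]
    for low, high in epoch_len_ranges:
        starts.append(starts[-1] + rng.randrange(low, high))
    return starts


def epoch_signal(starts: Sequence[int], n_steps: int, active_epochs: Sequence[int]) -> List[int]:
    # 1 on every time step that falls inside one of `active_epochs`, else 0.
    bounds = list(starts) + [n_steps]
    signal = [0] * n_steps
    for epoch in active_epochs:
        for t in range(bounds[epoch], min(bounds[epoch + 1], n_steps)):
            signal[t] = 1
    return signal


rng = random.Random(0)
starts = sample_epoch_starts(rng, [(2, 3), (3, 4)])  # durations are forced to 2 and 3
hold = epoch_signal(starts, n_steps=8, active_epochs=[0, 1])       # hold until movement
target_on = epoch_signal(starts, n_steps=8, active_epochs=[0, 2])  # stimulus and movement
print(starts)     # [0, 2, 5]
print(hold)       # [1, 1, 1, 1, 1, 0, 0, 0]
print(target_on)  # [1, 1, 0, 0, 0, 1, 1, 1]
```

The returned starts play the role that epoch_start_idxs plays in DelayedReachTrialSpec.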
n_validation_trials: int (property)
Number of trials in the validation set.
get_train_trial_with_intervenor_params(key: PRNGKeyArray) -> AbstractTaskTrialSpec
Inherited from feedbax.task.AbstractTask
eval_trials(model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval_with_loss(model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval(model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT
Inherited from feedbax.task.AbstractTask
eval_ensemble_with_loss(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]
Inherited from feedbax.task.AbstractTask
eval_ensemble(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT
Inherited from feedbax.task.AbstractTask
eval_train_batch(model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Inherited from feedbax.task.AbstractTask
eval_ensemble_train_batch(models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Inherited from feedbax.task.AbstractTask
add_intervenors_to_base_model(model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]
Inherited from feedbax.task.AbstractTask
activate_interventions(labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self
Inherited from feedbax.task.AbstractTask
validation_plots(states, trial_specs: Optional[AbstractTaskTrialSpec] = None) -> Mapping[str, go.Figure]
Inherited from feedbax.task.AbstractTask
get_train_trial(key: PRNGKeyArray) -> DelayedReachTrialSpec
Random reach endpoints across the rectangular workspace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A random key for generating the trial. | required |
get_validation_trials(key: PRNGKeyArray) -> DelayedReachTrialSpec
Center-out reach sets in a grid across the rectangular workspace.
Abstract base classes
feedbax.task.AbstractTaskTrialSpec (Module)
Abstract base class for trial specifications provided by a task.
Attributes:
Name | Type | Description |
---|---|---|
inits | AbstractVar[InitsWhereDict] | A mapping from where lambdas, which select model substates, to the values used to initialize those substates. |
inputs | AbstractVar[PyTree] | A PyTree of inputs to the model. |
target | AbstractVar[PyTree[Array]] | A PyTree of target states. |
intervene | AbstractVar[Mapping[IntervenorLabelStr, AbstractIntervenorInput]] | A mapping from unique intervenor names, to per-trial intervention parameters. |
feedbax.task.AbstractTask (Module)
Abstract base class for tasks.
Provides methods for evaluating suitable models or ensembles of models on training and validation trials.
Subclasses must provide:
- a method that generates training trials
- a method that generates a set of validation trials
- a field for a loss function that grades performance on the task
Attributes:
Name | Type | Description |
---|---|---|
loss_func | AbstractVar[AbstractLoss] | The loss function that grades task performance. |
n_steps | AbstractVar[int] | The number of time steps in the task trials. |
seed_validation | AbstractVar[int] | The random seed for generating the validation trials. |
intervention_specs | AbstractVar[TaskInterventionSpecs] | Mappings from unique intervenor names, to specifications for generating per-trial intervention parameters. Distinct fields provide mappings for training and validation trials, though the two may be identical depending on scheduling. |
validation_trials: AbstractTaskTrialSpec (cached property)
The set of validation trials associated with the task.
get_train_trial(key: PRNGKeyArray) -> AbstractTaskTrialSpec (abstractmethod)
Return a single training trial specification.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A random key for generating the trial. | required |
get_train_trial_with_intervenor_params(key: PRNGKeyArray) -> AbstractTaskTrialSpec
Return a single training trial specification, including intervention parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A random key for generating the trial. | required |
get_validation_trials(key: PRNGKeyArray) -> AbstractTaskTrialSpec (abstractmethod)
Return a set of validation trials, given a random key.
Subclasses must override this method. However, the validation set that is used during training, and provided by self.validation_trials, is determined by the field self.seed_validation, which must also be implemented by subclasses.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | A random key for generating the validation set. | required |
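The following is a plain-Python analogue of this arrangement, with random.Random standing in for JAX PRNG keys and a hypothetical ToyTask class. The subclass implements the generator. A cached property calls it with a generator seeded from seed_validation, so the validation set is identical across evaluations and is built only once per task instance.

```python
import random
from functools import cached_property
from typing import List


class ToyTask:
    """Hypothetical stand-in for a task; not a feedbax class."""

    def __init__(self, seed_validation: int, n_validation: int = 4):
        self.seed_validation = seed_validation
        self.n_validation = n_validation

    def get_validation_trials(self, rng: random.Random) -> List[float]:
        # Subclass-specific generation; the caller supplies the randomness.
        return [rng.uniform(-1.0, 1.0) for _ in range(self.n_validation)]

    @cached_property
    def validation_trials(self) -> List[float]:
        # Always generated from seed_validation, so the validation set is the same
        # every time, and it is computed only once per task instance.
        return self.get_validation_trials(random.Random(self.seed_validation))


task = ToyTask(seed_validation=5)
print(task.validation_trials == ToyTask(seed_validation=5).validation_trials)  # True
print(task.validation_trials is task.validation_trials)                        # True
```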
n_validation_trials() -> int
Number of trials in the validation set.
eval_trials(model: AbstractModel[StateT], trial_specs: AbstractTaskTrialSpec, keys: PRNGKeyArray) -> Tuple[StateT, LossDict]
Evaluate a model on a set of trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractModel[StateT] | The model to evaluate. | required |
trial_specs | AbstractTaskTrialSpec | The set of trials to evaluate the model on. | required |
keys | PRNGKeyArray | For providing randomness during model evaluation. | required |
eval_with_loss(model: AbstractModel[StateT], key: PRNGKeyArray) -> Tuple[StateT, LossDict]
Evaluate a model on the task's validation set of trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractModel[StateT] | The model to evaluate. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. | required |
Returns:
Type | Description |
---|---|
StateT | The evaluated model states. |
LossDict | The losses for the trials in the validation set. |
eval(model: AbstractModel[StateT], key: PRNGKeyArray) -> StateT
Return states for a model evaluated on the task's set of validation trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractModel[StateT] | The model to evaluate. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. | required |
eval_ensemble_with_loss(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> tuple[StateT, LossDict]
Return states and losses for an ensemble of models evaluated on the task's set of validation trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | AbstractModel[StateT] | The ensemble of models to evaluate. | required |
n_replicates | int | The number of models in the ensemble. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. Will be split into … | required |
ensemble_random_trials | bool | If … | True |
eval_ensemble(models: AbstractModel[StateT], n_replicates: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> StateT
Return states for an ensemble of models evaluated on the task's set of validation trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | AbstractModel[StateT] | The ensemble of models to evaluate. | required |
n_replicates | int | The number of models in the ensemble. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. Will be split into … | required |
ensemble_random_trials | bool | If … | True |
eval_train_batch(model: AbstractModel[StateT], batch_size: int, key: PRNGKeyArray) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Evaluate a model on a single batch of training trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | AbstractModel[StateT] | The model to evaluate. | required |
batch_size | int | The number of trials in the batch. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. | required |
Returns:
Type | Description |
---|---|
StateT | The evaluated model states. |
LossDict | The losses for the trials in the batch. |
AbstractTaskTrialSpec | The trial specifications for the batch. |
eval_ensemble_train_batch(models: AbstractModel[StateT], n_replicates: int, batch_size: int, key: PRNGKeyArray, ensemble_random_trials: bool = True) -> Tuple[StateT, LossDict, AbstractTaskTrialSpec]
Evaluate an ensemble of models on a single training batch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
models | AbstractModel[StateT] | The ensemble of models to evaluate. | required |
n_replicates | int | The number of models in the ensemble. | required |
batch_size | int | The number of trials in the batch to evaluate. | required |
key | PRNGKeyArray | For providing randomness during model evaluation. | required |
ensemble_random_trials | bool | If … | True |
Returns:
Type | Description |
---|---|
StateT | The evaluated model states, for each trial and each model in the ensemble. |
LossDict | The losses for the trials in the batch, for each model in the ensemble. |
AbstractTaskTrialSpec | The trial specifications for the batch. |
add_intervenors_to_base_model(model: AbstractStagedModel[StateT]) -> AbstractStagedModel[StateT]
Add the task's scheduled intervenors to a model.
Assumes that the model has the appropriate structure to admit the intervention. This depends on the where and stage_name properties stored in the task's intervention_specs field by the original call to schedule_intervenor that added the intervenors to the task:
- where should pick out an AbstractStagedModel component of model.
- If defined, stage_name should be the name of one of the stages of the AbstractStagedModel component picked out by where.
Any existing intervenors in the model that were scheduled with another task will be removed to prevent conflicts. Other intervenors, which were added directly to the model without being scheduled with a task, will not be removed.
Note
This method is mostly useful when evaluating a trained model on a task with a different set of interventions than the one it was trained on.
activate_interventions(labels: NonCharSequence[IntervenorLabelStr] | Literal['all', 'none'], labels_validation: Optional[NonCharSequence[IntervenorLabelStr] | Literal['all', 'none']] = None, validation_same_schedule) -> Self
Return a task in which scheduling is active only for the interventions with the given labels.
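The labels argument accepts either a sequence of intervenor labels or one of the strings 'all' and 'none'. The sketch below shows one plausible way to resolve such an argument into a set of active labels. resolve_active_labels and the example labels are hypothetical and are not part of feedbax. The check against a bare string reflects a guess at why the annotation uses NonCharSequence: a str is itself a sequence of characters, so without the check, "curl" would be read as the labels "c", "u", "r", "l".

```python
from typing import Iterable, Literal, Sequence, Set, Union

LabelsArg = Union[Sequence[str], Literal["all", "none"]]


def resolve_active_labels(labels: LabelsArg, scheduled: Iterable[str]) -> Set[str]:
    scheduled_set = set(scheduled)
    if labels == "all":
        return scheduled_set
    if labels == "none":
        return set()
    if isinstance(labels, str):
        # A lone string would otherwise be iterated character by character.
        raise TypeError("expected a sequence of labels, not a single string")
    unknown = set(labels) - scheduled_set
    if unknown:
        raise ValueError(f"labels not scheduled with this task: {sorted(unknown)}")
    return set(labels)


scheduled = ["curl_field", "perturb_target"]  # hypothetical intervenor labels
print(resolve_active_labels("all", scheduled) == set(scheduled))  # True
print(resolve_active_labels(["curl_field"], scheduled))           # {'curl_field'}
```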
validation_plots(states, trial_specs: Optional[AbstractTaskTrialSpec] = None) -> Mapping[str, go.Figure] (abstractmethod)
Returns a basic set of plots to visualize performance on the task.
Useful functions for building tasks
feedbax.task.internal_grid_points(bounds: Float[Array, 'bounds=2 ndim=2'], n: int = 2) -> Float[Array, 'n**ndim ndim=2']
Return an array of evenly-spaced grid points internal to the bounds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
bounds | Float[Array, 'bounds=2 ndim=2'] | The outer bounds of the grid. | required |
n | int | The number of internal grid points along each dimension. | 2 |
Example
internal_grid_points(
    bounds=((0, 0), (9, 9)),
    n=2,
)
# Array([[3., 3.], [6., 3.], [3., 6.], [6., 6.]])
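The following pure-Python sketch reproduces the arithmetic behind the example above. It returns nested lists rather than a JAX array and is not feedbax's code. It assumes, as the example suggests, that the n internal points along each dimension divide it into n + 1 equal intervals and that the first coordinate varies fastest.

```python
from typing import List, Sequence


def internal_grid_points(bounds: Sequence[Sequence[float]], n: int = 2) -> List[List[float]]:
    lower, upper = bounds
    # For each dimension, split [lower, upper] into n + 1 equal intervals
    # and keep the n division points strictly inside the bounds.
    axes = [
        [lower[d] + (upper[d] - lower[d]) * (i + 1) / (n + 1) for i in range(n)]
        for d in range(len(lower))
    ]
    # Cartesian product of the per-dimension points, first coordinate varying fastest.
    points: List[List[float]] = [[]]
    for axis in axes:
        points = [p + [x] for x in axis for p in points]
    return points


print(internal_grid_points(((0, 0), (9, 9)), n=2))
# [[3.0, 3.0], [6.0, 3.0], [3.0, 6.0], [6.0, 6.0]]
```

With ndim dimensions the result has n ** ndim points, matching the 'n**ndim ndim=2' shape in the signature.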
Using lambda functions as dictionary keys
feedbax.task.WhereDict
An OrderedDict that allows limited use of where lambdas as keys.
In particular, keys can be lambdas that take a single argument, and return a single (nested) attribute accessed from that argument.
Lambdas are parsed to equivalent strings, which can be used interchangeably as keys. For example, the following are equivalent when init_spec is a WhereDict:
init_spec[lambda state: state.mechanics.effector]
init_spec['mechanics.effector']
Finally, a tuple[Callable, str] may also be provided as a key, for cases where different unique entries must be included for the same Callable. For example, the following are equivalent:
init_spec[(lambda state: state.mechanics.effector, "first")]
init_spec['mechanics.effector#first']
Note that the hash symbol # is used to concatenate the usual string representation for the callable, with the paired string.
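One way to turn a where lambda into a string like 'mechanics.effector' is to call it on a proxy object that records attribute accesses. The sketch below uses that technique. It is not necessarily how WhereDict parses lambdas, since the description above says only that lambdas are "parsed to equivalent strings". where_to_str and ToyWhereDict are hypothetical names. Unlike a complete mapping, this sketch normalizes keys only for subscript access.

```python
from typing import Any, Callable, Tuple, Union

WhereKey = Union[str, Callable[[Any], Any], Tuple[Callable[[Any], Any], str]]


class _AttrPathRecorder:
    """Stands in for the state object and remembers which attributes were accessed."""

    def __init__(self, path: Tuple[str, ...] = ()):
        self._path = path

    def __getattr__(self, name: str) -> "_AttrPathRecorder":
        # Only called for attributes not found normally, i.e. anything but _path.
        return _AttrPathRecorder(self._path + (name,))


def where_to_str(key: WhereKey) -> str:
    if isinstance(key, str):
        return key
    if isinstance(key, tuple):
        func, label = key
        return f"{where_to_str(func)}#{label}"
    result = key(_AttrPathRecorder())
    if not isinstance(result, _AttrPathRecorder):
        raise ValueError("a where lambda may only access (nested) attributes of its argument")
    return ".".join(result._path)


class ToyWhereDict(dict):
    """Hypothetical sketch: normalizes every subscript key to its string form."""

    def __setitem__(self, key: WhereKey, value: Any) -> None:
        super().__setitem__(where_to_str(key), value)

    def __getitem__(self, key: WhereKey) -> Any:
        return super().__getitem__(where_to_str(key))


init_spec = ToyWhereDict()
init_spec[lambda state: state.mechanics.effector] = "initial effector state"
print(init_spec["mechanics.effector"])  # initial effector state
print(where_to_str((lambda state: state.mechanics.effector, "first")))  # mechanics.effector#first
```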
Performance
For typical initialization mappings (1-10 items), construction is on the order of 50x slower than OrderedDict. Access is about 2-20x slower, depending on whether it is indexed by string or by callable.
However, we only need to do a single construction and a single access of init_spec per batch/evaluation, so performance shouldn't matter too much in practice: the added overhead is <50 µs/batch, and a batch normally takes at least 20,000 µs to train.
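Setting the two figures from this paragraph side by side gives an upper bound on the relative overhead:

$$
\frac{50\ \mu\mathrm{s}}{20\,000\ \mu\mathrm{s}} = 0.0025 = 0.25\%
$$

So WhereDict adds at most about a quarter of a percent to the time per batch, and proportionally less for batches that take longer than 20,000 µs.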