Designing a task
To design a custom task, we need to subclass `AbstractTask` and implement the methods `get_train_trial` and `get_validation_trials`, which, as their names suggest, define how to construct the task's training and validation trials.
Each of these methods must return an instance of some subclass of `AbstractTaskTrialSpec`, which is a PyTree of trial specifications: the parameters that define trials.
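The two-method contract resembles a Python abstract base class. As an illustration only, here is a sketch using the standard `abc` module. The `ToyTrialSpec`, `ToyAbstractTask`, and `IncompleteTask` names are hypothetical, and feedbax's real `AbstractTask` has more members than these two methods. It shows that a subclass which skips one of the required methods cannot be instantiated:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ToyTrialSpec:
    """Hypothetical stand-in for a trial specification PyTree."""
    inits: dict
    inputs: dict
    target: list


class ToyAbstractTask(ABC):
    """Hypothetical stand-in mirroring the two methods a task must define."""

    @abstractmethod
    def get_train_trial(self, key) -> ToyTrialSpec:
        """Return the specification of a single training trial."""

    @abstractmethod
    def get_validation_trials(self, key) -> ToyTrialSpec:
        """Return the specification of the whole validation set."""


class IncompleteTask(ToyAbstractTask):
    # Only one of the two required methods is implemented.
    def get_train_trial(self, key) -> ToyTrialSpec:
        return ToyTrialSpec(inits={}, inputs={}, target=[])


try:
    IncompleteTask()
except TypeError as err:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    print("cannot instantiate:", err)
```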
What's in a trial specification?
For example, consider `SimpleReachTrialSpec`, which is the kind of trial specification you might use if you're building a variant of a simple reaching task:
```python
class SimpleReachTrialSpec(AbstractReachTrialSpec):
    """Trial specification for a simple reaching task."""

    inits: WhereDict
    inputs: SimpleReachTaskInputs
    target: CartesianState
    intervene: dict[str, Array] = field(default_factory=dict)
```
What do all these fields mean?
`inits`: Initial states
Provide the states used to initialize the model as a `WhereDict`. This is similar to a normal `dict`, but it lets us use certain `lambda` functions as keys.
Here's a sketch of how we might define this part of a trial spec, if we wanted to start our trials with the mechanical effector at a certain (Cartesian) position and velocity:
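```python
import jax.numpy as jnp

from feedbax.state import CartesianState
from feedbax.task import WhereDict

inits = WhereDict({
    (lambda state: state.mechanics.effector): CartesianState(
        pos=jnp.array(...),  # Fill this in with actual position data for the trial(s).
        vel=jnp.array(...),
    )
})
```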
We need to make sure to pass the right kind of data for the part of the state we want to initialize, so it helps to examine the state PyTree. In this case, our model state is a `SimpleFeedbackState`, which has a field `mechanics: MechanicsState`, which in turn has a field `effector: CartesianState`. That's the part we want to initialize, so we have to supply a `CartesianState`.
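To make "the right kind of data" concrete, here is a minimal sketch of a type check against the node a `lambda` selects. The `Toy*` dataclasses are hypothetical stand-ins that mirror the nesting just described; they are not feedbax's state classes, and `check_init` is not part of feedbax:

```python
from dataclasses import dataclass


# Hypothetical stand-ins mirroring the nesting described in the text:
# SimpleFeedbackState -> mechanics: MechanicsState -> effector: CartesianState.
@dataclass(frozen=True)
class ToyCartesianState:
    pos: tuple
    vel: tuple


@dataclass(frozen=True)
class ToyMechanicsState:
    effector: ToyCartesianState


@dataclass(frozen=True)
class ToyFeedbackState:
    mechanics: ToyMechanicsState


def check_init(state, where, value) -> None:
    """Raise if `value` is not the same type as the node that `where` selects."""
    node = where(state)
    if type(node) is not type(value):
        raise TypeError(f"expected {type(node).__name__}, got {type(value).__name__}")


state = ToyFeedbackState(
    ToyMechanicsState(ToyCartesianState(pos=(0.0, 0.0), vel=(0.0, 0.0)))
)

# Matches: the lambda selects a ToyCartesianState, and we supply one.
check_init(
    state,
    lambda s: s.mechanics.effector,
    ToyCartesianState(pos=(0.1, 0.2), vel=(0.0, 0.0)),
)

# Mismatch: a bare position where a whole effector state is expected.
try:
    check_init(state, lambda s: s.mechanics.effector, (0.1, 0.2))
except TypeError as err:
    print(err)  # expected ToyCartesianState, got tuple
```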
We could have initialized only the position of the effector, like so:
```python
inits = WhereDict({
    (lambda state: state.mechanics.effector.pos): jnp.array(...)
})
```
The nice thing about specifying model initializations this way is that we can provide data for any part of the model state, using an appropriate `lambda`.
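How can a `lambda` used as a key say *where* to put data? One way, sketched below purely for illustration, is to call the lambda on a proxy object that records the attribute names it accesses, then rebuild the nested state with that one node replaced. This is an assumption about the general technique, not feedbax's implementation; a real tool in the same spirit is Equinox's `tree_at`, which also takes a `where` function, but whether feedbax relies on it is not shown here. All names below (`_PathRecorder`, `attribute_path`, `set_at`, `apply_inits`, `Toy*`) are hypothetical:

```python
import dataclasses
from dataclasses import dataclass


# Hypothetical stand-ins for the nested model state.
@dataclass(frozen=True)
class ToyCartesianState:
    pos: tuple
    vel: tuple


@dataclass(frozen=True)
class ToyMechanicsState:
    effector: ToyCartesianState


@dataclass(frozen=True)
class ToyFeedbackState:
    mechanics: ToyMechanicsState


class _PathRecorder:
    """Records the chain of attribute names that a `where` lambda accesses."""

    def __init__(self, path=()):
        self._path = path

    def __getattr__(self, name):
        # Only reached for names not found normally, i.e. the lambda's accesses.
        return _PathRecorder(self._path + (name,))


def attribute_path(where) -> tuple:
    """E.g. lambda s: s.mechanics.effector.pos -> ('mechanics', 'effector', 'pos').

    Only plain attribute chains are supported; indexing is not.
    """
    return where(_PathRecorder())._path


def set_at(obj, path, value):
    """Copy nested frozen dataclasses, replacing the node at `path` with `value`."""
    if not path:
        return value
    head, rest = path[0], path[1:]
    return dataclasses.replace(obj, **{head: set_at(getattr(obj, head), rest, value)})


def apply_inits(state, inits: dict):
    """Apply each (where-lambda -> value) entry to `state`, in insertion order."""
    for where, value in inits.items():
        state = set_at(state, attribute_path(where), value)
    return state


state = ToyFeedbackState(
    ToyMechanicsState(ToyCartesianState(pos=(0.0, 0.0), vel=(0.0, 0.0)))
)
# Initialize only the position; the velocity keeps its previous value.
new_state = apply_inits(state, {
    (lambda s: s.mechanics.effector.pos): (0.1, 0.2),
})
print(new_state.mechanics.effector)
# ToyCartesianState(pos=(0.1, 0.2), vel=(0.0, 0.0))
```

Because the state objects are frozen, the original `state` is left unchanged and a modified copy is returned, which matches the immutable style of JAX PyTrees.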
`inputs`: Model inputs
The model typically needs information about each trial of the task it is supposed to complete. The structure of this information depends on the class of task; in fact, this is the part of the trial specification that varies the most between task classes. We specify it with yet another PyTree.
In the case of `SimpleReachTrialSpec`, the type of `inputs` is `SimpleReachTaskInputs`:
```python
class SimpleReachTaskInputs(AbstractTaskInput):
    """Model input for a simple reaching task.

    Attributes:
        effector_target: The trajectory of effector target states to be presented to the model.
    """

    effector_target: CartesianState
```
For simple reaches, the only model input we provide is the target the effector should reach on the current trial.
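Since `effector_target` is a *trajectory* of target states, each of its arrays has a time dimension. As one possible choice (an assumption, not necessarily what feedbax's built-in reaching tasks do), the target can simply hold the goal position at every time step with zero velocity. The sketch below uses numpy and a hypothetical `ToyCartesianState` stand-in; the same construction works with `jax.numpy`:

```python
from typing import NamedTuple

import numpy as np


class ToyCartesianState(NamedTuple):
    """Hypothetical stand-in for CartesianState, with position and velocity only."""
    pos: np.ndarray  # shape (n_steps, 2)
    vel: np.ndarray  # shape (n_steps, 2)


def static_target_trajectory(goal_pos, n_steps: int) -> ToyCartesianState:
    """Hold `goal_pos` at every time step, with zero target velocity."""
    goal_pos = np.asarray(goal_pos, dtype=float)
    pos = np.broadcast_to(goal_pos, (n_steps, goal_pos.shape[-1])).copy()
    vel = np.zeros_like(pos)
    return ToyCartesianState(pos=pos, vel=vel)


target = static_target_trajectory([0.5, 0.25], n_steps=100)
print(target.pos.shape, target.vel.shape)  # (100, 2) (100, 2)
```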
`target`: Target states
In the case of simple reaches, the same target state that is provided as input to the model is also used by the loss function to score the model's reach performance. This is straightforward enough: it's just the trajectory we want the effector to follow, in Cartesian coordinates.
See the example on loss functions for more information on how this information is actually used in scoring.
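To show the idea of scoring against `target` without anticipating that example, here is a toy stand-in (not one of feedbax's `AbstractLoss` classes): the mean, over time steps, of the squared Euclidean distance between the effector's position trajectory and the target position trajectory:

```python
import numpy as np


def toy_position_loss(effector_pos: np.ndarray, target_pos: np.ndarray) -> float:
    """Mean over time steps of the squared Euclidean distance to the target.

    Both arrays have shape (n_steps, 2).
    """
    sq_dist = np.sum((effector_pos - target_pos) ** 2, axis=-1)
    return float(np.mean(sq_dist))


target_pos = np.zeros((4, 2))
effector_pos = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]])
print(toy_position_loss(effector_pos, target_pos))  # 0.5
```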
`intervene`: Intervention parameters
This holds the trial-by-trial parameters for each intervention we've scheduled with [`schedule_intervenors`](/feedbax/examples/3_intervening#scheduling-a-force-field). You don't need to construct it yourself when writing a subclass of `AbstractTask`.
However, if you write your own subclass of `AbstractTaskTrialSpec`, you should make sure to include the field `intervene: dict[str, Array] = field(default_factory=dict)`.
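The `default_factory=dict` part matters. Assuming trial specs follow ordinary dataclass semantics (Equinox modules are dataclasses, and Equinox's `field` passes `default_factory` through), each instance gets its own fresh empty dict rather than sharing one mutable default. This sketch uses the standard library `dataclasses.field` and a hypothetical `ToyTrialSpec`:

```python
from dataclasses import dataclass, field


@dataclass
class ToyTrialSpec:
    target: float
    # Same pattern as the required `intervene` field: each instance gets
    # its own fresh empty dict instead of sharing one mutable default.
    intervene: dict = field(default_factory=dict)


a = ToyTrialSpec(target=1.0)
b = ToyTrialSpec(target=2.0)
a.intervene["force_field"] = [0.0, 1.0]
print(b.intervene)  # {}
```

Writing `intervene: dict = {}` instead would not even be accepted: standard dataclasses raise a `ValueError` for a mutable default like this.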
Building a trial specification
Putting together what we've just seen, a function that constructs a trial specification for a simple reaching task will look something like this:
```python
import jax.numpy as jnp

from feedbax.state import CartesianState
from feedbax.task import SimpleReachTaskInputs, SimpleReachTrialSpec, WhereDict


def get_simple_reach_trial(key):
    # `key` could be used to sample random positions; this sketch leaves them unspecified.
    init_effector_pos = jnp.array(...)

    inits = WhereDict({
        (lambda state: state.mechanics.effector.pos): init_effector_pos,
    })

    effector_target = CartesianState(
        pos=jnp.array(...),
        vel=jnp.array(...),
    )

    return SimpleReachTrialSpec(
        inits=inits,
        # The target state of the effector is the info both 1) that the model needs
        # to complete the task:
        inputs=SimpleReachTaskInputs(
            effector_target=effector_target,
        ),
        # and 2) that the loss function needs to score the task:
        target=effector_target,
        # There's no need to specify `intervene` here.
    )
```
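The sketch above leaves `key` unused. As one hypothetical way to fill in the `...` placeholders, the initial and target positions could be sampled uniformly within a rectangular workspace. In JAX you would split `key` with `jax.random.split` and draw each position with `jax.random.uniform(subkey, (2,), minval=..., maxval=...)`; the numpy version below uses an integer seed instead so that it runs without JAX. The `sample_reach_endpoints` name and the workspace bounds are illustrative assumptions:

```python
import numpy as np


def sample_reach_endpoints(seed: int, workspace=((-1.0, -1.0), (1.0, 1.0))):
    """Sample independent initial and target positions uniformly in a box.

    `workspace` is ((x_min, y_min), (x_max, y_max)).
    """
    rng = np.random.default_rng(seed)
    low, high = (np.asarray(bound, dtype=float) for bound in workspace)
    init_pos = rng.uniform(low, high)  # shape (2,)
    target_pos = rng.uniform(low, high)  # shape (2,)
    return init_pos, target_pos


init_pos, target_pos = sample_reach_endpoints(seed=0)
print(init_pos.shape, target_pos.shape)  # (2,) (2,)
```

The same seed always yields the same endpoints, which mirrors how a fixed JAX key makes a trial reproducible.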
Writing a subclass of `AbstractTask`
Again, we have to write two methods like the function we just sketched, each returning a `SimpleReachTrialSpec`: `get_train_trial` and `get_validation_trials`. We also have to implement:

- one property: `n_validation_trials`;
- three fields: `loss_func`, `n_steps`, and `seed_validation`.
Note that `get_train_trial` should return a trial spec for a single training trial, while `get_validation_trials` should return a trial spec for all the validation trials at once. This is reflected in the shapes of the arrays we use to build the trial spec: in `get_validation_trials`, there is an extra dimension whose size is the number of trials in the validation set.
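The extra dimension can be pictured as stacking single-trial specs along a new trial axis. In the sketch below that axis is placed first, which is an assumption for illustration; plain dicts of numpy arrays stand in for the trial-spec PyTree. For arbitrary PyTrees, JAX offers the same operation via `jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trials)`:

```python
import numpy as np


def stack_trials(trials):
    """Stack per-trial dicts of arrays along a new leading (trial) axis."""
    return {name: np.stack([trial[name] for trial in trials]) for name in trials[0]}


n_steps = 100
single = {
    "target_pos": np.zeros((n_steps, 2)),  # one trial: (n_steps, 2)
    "init_pos": np.zeros(2),  # one trial: (2,)
}
validation = stack_trials([single] * 8)
print(validation["target_pos"].shape)  # (8, 100, 2)
print(validation["init_pos"].shape)  # (8, 2)
```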
```python
from feedbax.intervene import AbstractIntervenorInput
from feedbax.loss import AbstractLoss
from feedbax.task import AbstractTask, SimpleReachTaskInputs, SimpleReachTrialSpec


class MySimpleReachTask(AbstractTask):
    loss_func: AbstractLoss
    n_steps: int  # The number of time steps in each trial.
    seed_validation: int
    intervention_specs: dict[str, AbstractIntervenorInput]
    intervention_specs_validation: dict[str, AbstractIntervenorInput]

    def get_train_trial(self, key) -> SimpleReachTrialSpec:
        """Return a single training trial specification.

        Arguments:
            key: A random key for generating the trial.
        """
        inits = ...
        effector_target = ...
        return SimpleReachTrialSpec(
            inits=inits,
            inputs=SimpleReachTaskInputs(
                effector_target=effector_target,
            ),
            target=effector_target,
        )

    def get_validation_trials(self, key) -> SimpleReachTrialSpec:
        """Return a set of validation trials.

        Arguments:
            key: A random key for generating the validation set.
        """
        inits = ...
        effector_target = ...
        return SimpleReachTrialSpec(
            inits=inits,
            inputs=SimpleReachTaskInputs(
                effector_target=effector_target,
            ),
            target=effector_target,
        )

    # We also need to implement this property.
    @property
    def n_validation_trials(self) -> int:
        ...
```
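`n_validation_trials` should agree with the size of that extra dimension in the arrays returned by `get_validation_trials`. As a hypothetical validation design (not necessarily what feedbax's built-in tasks use), consider center-out reaches from the origin to evenly spaced points on a circle, with one trial per direction. The `ToyCenterOutValidation` class and its parameters are illustrative assumptions:

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyCenterOutValidation:
    """Hypothetical validation design: reaches from the origin to points on a circle."""
    n_directions: int = 8
    reach_length: float = 0.5

    @property
    def n_validation_trials(self) -> int:
        # One validation trial per reach direction.
        return self.n_directions

    def target_positions(self) -> np.ndarray:
        """Return target positions with shape (n_validation_trials, 2)."""
        angles = np.linspace(0.0, 2 * np.pi, self.n_directions, endpoint=False)
        return self.reach_length * np.stack([np.cos(angles), np.sin(angles)], axis=-1)


val = ToyCenterOutValidation()
print(val.n_validation_trials, val.target_positions().shape)  # 8 (8, 2)
```

Deriving the property from the same parameters that build the targets keeps the two from drifting apart.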
Defining an entirely new class of tasks
We can design tasks that are not part of an existing class of tasks. In that case, we need to define our own subclass of `AbstractTaskTrialSpec`, which is the PyTree of information defining a task trial. In particular,