Loss functions
In Feedbax, each task object (that is, each instance of any type of AbstractTask) must provide a loss function that is used to score performance on that task.
Background on loss functions
In optimal control, we measure how well a task is performed using a cost function. In machine learning, the same concept is called a loss function. The larger the value of the loss function, the worse the performance on the task.
Loss functions may be composed of multiple loss terms, which measure different aspects of performance. Typically, terms are evaluated separately then weighted and summed to give the total value of the loss.
A movement model generates trajectories (states) of positions, velocities, and forces over time. Loss terms are often quadratic functions which are calculated locally for each time step in a trajectory, then summed or averaged to give the overall value for the loss term across the entire trajectory, prior to its being added to the other loss terms.
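The structure described above can be written compactly. The notation here is an assumption introduced for illustration, not taken from Feedbax: w_i is the weight of term i, x_t is the model state at time step t, T is the number of time steps, and x^* is a target value.

```latex
L = \sum_{i} w_i L_i,
\qquad
L_i = \sum_{t=1}^{T} \ell_i(x_t),
\qquad
\ell_i(x_t) = \lVert x_t - x^{*} \rVert^2 \quad \text{(for example, a quadratic position error)}
```

If a term is averaged rather than summed over time, the inner sum is simply divided by T.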
For example, when controlling a point mass to reach from a starting position to a goal position, the position error loss for the entire reach might be evaluated like so: at each time step of the reach, take the difference between the current position of the point mass and the goal position, and square it; then sum the results across all time steps. This value grows the longer and farther the point mass stays away from the goal position.
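As a minimal sketch of the position error calculation just described, the following uses plain NumPy on a made-up trajectory. The arrays and variable names are hypothetical, and this is not how Feedbax implements EffectorPositionLoss; jax.numpy provides the same functions used here.

```python
import numpy as np

# Hypothetical reach: 5 time steps of 2D positions (shape: time x 2).
positions = np.array([
    [0.0, 0.0],
    [0.5, 0.0],
    [0.8, 0.0],
    [1.0, 0.0],
    [1.0, 0.0],
])
goal = np.array([1.0, 0.0])

# Squared distance to the goal at each time step (sum over x and y).
squared_error = np.sum((positions - goal) ** 2, axis=-1)

# Sum across time steps to get a single value for the whole reach.
position_loss = np.sum(squared_error)
```

A trajectory that lingers far from the goal accumulates more squared error, so its value of position_loss is larger.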
Some loss terms may only be calculated for a subset of time steps. For a reaching task, we might include a loss term that penalizes the square of the velocity but only on the final time step, because we want the point mass to stop at the goal position rather than simply pass through it. It would not make sense to apply this loss term at time steps in the middle of the reach, when the point mass ought to be moving at a non-zero velocity toward the goal!
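To make the difference concrete, here is a small sketch that penalizes squared velocity on the final time step only, and compares it with applying the same penalty at every step. The velocity values are invented for illustration, and the code is not Feedbax's EffectorFinalVelocityLoss.

```python
import numpy as np

# Hypothetical velocities over a 4-step reach (shape: time x 2).
velocities = np.array([
    [0.0, 0.0],
    [2.0, 0.0],
    [1.0, 0.0],
    [0.5, 0.0],
])

# Penalize the squared velocity on the final time step only.
final_velocity_loss = np.sum(velocities[-1] ** 2)

# For comparison: the same penalty applied at every time step
# would also punish the movement needed to reach the goal.
all_steps_loss = np.sum(velocities ** 2)
```

Only the residual velocity at the end of the reach contributes to final_velocity_loss, whereas all_steps_loss is dominated by the mid-reach movement.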
Common loss terms, such as EffectorPositionLoss and EffectorFinalVelocityLoss, are defined in feedbax.loss.
A loss function with multiple weighted terms can be defined algebraically:
from feedbax.loss import (
    EffectorFinalVelocityLoss,
    EffectorPositionLoss,
    NetworkActivityLoss,
    NetworkOutputLoss,
)

loss_func = (
    1.0 * EffectorPositionLoss()
    + 1.0 * EffectorFinalVelocityLoss()
    + 1e-5 * NetworkOutputLoss()
    + 1e-5 * NetworkActivityLoss()
)
This is easy to read and visually close to the mathematics.
Typically, after constructing a loss we'll pass it to a task object we're constructing. For example:
task = SimpleReaches(loss_func=loss_func, ...)
Feedbax is flexible about how losses are constructed. The following achieves the same result as above.
from feedbax.loss import CompositeLoss

loss_func = CompositeLoss(
    [
        EffectorPositionLoss(),
        EffectorFinalVelocityLoss(),
        NetworkOutputLoss(),
        NetworkActivityLoss(),
    ],
    weights=[1.0, 1.0, 1e-5, 1e-5],
)
CompositeLoss allows us to group multiple terms into a single loss function. In fact, it is what gets constructed in the background when we write the loss function in algebraic form, as we did previously.
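One way an algebraic expression can produce a composite object is through Python's operator overloading. The sketch below shows the general idea with two hypothetical classes, `ToyTerm` and `ToyComposite`. These are assumptions made for illustration and do not reproduce Feedbax's actual AbstractLoss or CompositeLoss code.

```python
from dataclasses import dataclass


@dataclass
class ToyTerm:
    """Hypothetical stand-in for a single loss term."""
    label: str

    def __rmul__(self, weight: float) -> "ToyComposite":
        # `weight * term` wraps the term in a one-element composite.
        return ToyComposite({self.label: self}, {self.label: float(weight)})


@dataclass
class ToyComposite:
    """Hypothetical stand-in for a weighted collection of terms."""
    terms: dict
    weights: dict

    def __add__(self, other: "ToyComposite") -> "ToyComposite":
        # `a + b` merges the terms and weights of both composites.
        return ToyComposite(
            {**self.terms, **other.terms},
            {**self.weights, **other.weights},
        )


# The algebraic form builds a composite behind the scenes.
loss_func = 1.0 * ToyTerm("effector_position") + 1e-5 * ToyTerm("nn_output")
```

Because each operator returns a new composite, expressions of any length reduce to a single object holding all the terms and their weights.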
The two approaches can be mixed together.
loss_func = CompositeLoss(
    [
        EffectorPositionLoss(),
        EffectorFinalVelocityLoss(),
        NetworkOutputLoss(),
    ],
    weights=[1.0, 1.0, 1e-5],
)

# Add another term onto a composite loss we've already defined.
loss_func_plus = loss_func + 1e-5 * NetworkActivityLoss()
The term weights are often saved along with other model hyperparameters. In that case we can make our list of hyperparameters a little more readable by storing the weights in a dict, rather than a list.
weights = dict(
    effector_position=1.0,
    effector_final_velocity=1.0,
    nn_output=1e-5,
    nn_hidden=1e-5,
)

loss_func = CompositeLoss(
    dict(
        effector_position=EffectorPositionLoss(),
        effector_final_velocity=EffectorFinalVelocityLoss(),
        nn_output=NetworkOutputLoss(),
        nn_hidden=NetworkActivityLoss(),
    ),
    weights=weights,
)
LossDict
When called, all AbstractLoss instances return a LossDict object.
To see what this looks like, let's instantiate a model and evaluate it, then calculate the loss.
import jax
from feedbax.xabdeef import point_mass_nn_simple_reaches
key_init, key_eval = jax.random.split(jax.random.PRNGKey(0))
context = point_mass_nn_simple_reaches(key=key_init)
task, model = context.task, context.model
# We won't bother to train the model for this example
states = task.eval(model, key=key_eval)
# the loss function requires both the model states, and the trial information
loss = task.loss_func(states, task.validation_trials)
We can also get losses and states simultaneously.
loss, states = task.eval_with_loss(model, key=key_eval)
A LossDict is mostly similar to a regular dict and contains values for all the loss terms.
loss
However, it also provides the total loss.
loss.total
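The idea of a dict that also reports a total can be illustrated with a toy class. `ToyLossDict` below is a hypothetical stand-in written in plain Python with made-up values. The real LossDict holds JAX arrays, and its implementation may differ.

```python
class ToyLossDict(dict):
    """Hypothetical dict subclass that also reports the sum of its values."""

    @property
    def total(self):
        # Add up the values of all the loss terms.
        return sum(self.values())


toy_loss = ToyLossDict(effector_position=2.5, effector_final_velocity=0.5)

# Individual terms are accessed like in any dict...
position_term = toy_loss["effector_position"]

# ...while the aggregate is available as an attribute.
total = toy_loss.total
```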
At the end of a training run, TaskTrainer returns the loss history.
model, train_history = context.train(
    n_batches=500,
    batch_size=250,
    log_step=125,
    key=jax.random.PRNGKey(1),
)
A LossDict is part of the TaskTrainerHistory object.
import equinox as eqx
eqx.tree_pprint(train_history.loss)
Note that in this case, each loss term is a vector containing a history of the loss term's value across the training run.
from feedbax.plot import plot_loss_history
_ = plot_loss_history(train_history)
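When each loss term is stored as a vector over training batches, per-batch totals and final values follow from simple array operations. The sketch below uses a plain dict of NumPy arrays with invented numbers as a stand-in for the history. It does not use the actual TaskTrainerHistory structure.

```python
import numpy as np

# Hypothetical history: one value per training batch for each loss term.
toy_history = {
    "effector_position": np.array([4.0, 2.0, 1.0]),
    "nn_output": np.array([0.4, 0.2, 0.1]),
}

# Total loss per batch: element-wise sum across the terms.
total_history = sum(toy_history.values())

# Value of each term at the end of training.
final_values = {label: float(values[-1]) for label, values in toy_history.items()}
```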
Writing a new loss term
We can write new loss terms by subclassing AbstractLoss. To do so, we need to implement the label field and the term method.
The label is just a string that will be used to identify the loss term, for example in the legend of a loss history plot.
As arguments, the term method takes the state history and the trial specifications, and can calculate the loss from any part(s) of those. It returns the loss as a JAX array. Generally this array will contain just one dimension, corresponding to the batch of trials on which the loss is being evaluated; in other words, it returns a scalar for each trial in the batch. Note that it's generally unnecessary in term to refer explicitly to the batch dimension.
Any other dimensions (such as time) should be eliminated before returning. Some loss terms will only calculate a loss for a single time step, whereas others will need to aggregate (for example, take the sum of) values calculated for multiple time steps.
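The shape bookkeeping can be checked on dummy arrays. In this sketch, `toy_term` is a hypothetical function, not part of Feedbax. It reduces over the last axis twice, so it never mentions the batch axis and gives one scalar per trial whether or not a leading batch axis is present. NumPy is used here; jax.numpy provides the same functions.

```python
import numpy as np

n_trials, n_steps, n_dims = 3, 4, 2

# Hypothetical batched variable and target (shape: trials x time x dims).
variable = np.ones((n_trials, n_steps, n_dims))
target = np.zeros((n_trials, n_steps, n_dims))


def toy_term(variable, target):
    # Sum squared error over the variable's components (last axis).
    squared_error = np.sum((variable - target) ** 2, axis=-1)
    # Sum over time, which is now the last axis.
    return np.sum(squared_error, axis=-1)


batch_loss = toy_term(variable, target)          # one value per trial
single_loss = toy_term(variable[0], target[0])   # a single trial, no batch axis
```

Using negative axis indices is one simple way to keep the reduction independent of any leading batch dimension.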
The following is a mock example of a typical subclass of AbstractLoss.
import jax.numpy as jnp
from jaxtyping import Array

from feedbax.loss import AbstractLoss


class SomethingLoss(AbstractLoss):
    label: str = "effector_position"

    def term(self, states, trial_specs) -> Array:
        # Sum over length of variable vector
        loss = jnp.sum(
            (states.some_variable - trial_specs.some_target) ** 2,
            axis=-1
        )

        # Sum over time (if calculated for multiple time steps)
        loss = jnp.sum(loss, axis=-1)

        return loss
Here, trial_specs.some_target is some array of target values for the variable, provided by the task when it constructs the trial specifications.