
Loss functions

feedbax.loss.CompositeLoss (AbstractLoss)

Incorporates multiple simple loss terms and their relative weights.

__init__(terms: Mapping[str, AbstractLoss] | Sequence[AbstractLoss], weights: Optional[Mapping[str, float] | Sequence[float]] = None, label: str = '', user_labels: bool = True)

Note

During construction, the user may pass dictionaries and/or sequences of AbstractLoss instances (terms) and their scalar weights (weights).

Any CompositeLoss instances in terms are flattened, and their simple terms are incorporated directly into the new composite loss, with the weight of each such simple term multiplied by the weight given in weights for its parent composite term.

If a composite term has a user-specified label, that label is prepended to the labels of its component terms on flattening. If the flattened terms still do not have unique labels, each is suffixed with the lowest integer that makes it unique.
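The uniquification rule described above can be illustrated with a small helper. This is an illustrative sketch, not feedbax's actual code, and the exact suffix format used by the library may differ:

```python
def uniquify_labels(labels):
    """Suffix each duplicate label with the lowest integer that
    makes it unique, preserving order.

    Illustrative sketch of the rule described in the note above;
    the suffix format ("label_0") is an assumption.
    """
    seen = set()
    unique = []
    for label in labels:
        if label not in seen:
            new = label
        else:
            i = 0
            while f"{label}_{i}" in seen:
                i += 1
            new = f"{label}_{i}"
        seen.add(new)
        unique.append(new)
    return unique
```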

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| terms | Mapping[str, AbstractLoss] \| Sequence[AbstractLoss] | The sequence or mapping of loss terms to be included. | required |
| weights | Optional[Mapping[str, float] \| Sequence[float]] | A float PyTree of the same structure as terms, giving the scalar term weights. By default, all terms have equal weight. | None |
| label | str | The label for the composite loss. | '' |
| user_labels | bool | If True, the keys of terms (if it is a mapping) are used as term labels, instead of the label field of each term. Passing False is useful when it is convenient to match up the structure of terms and weights as dicts, whose keys would otherwise provide the labels, while still using each term's default label. | True |
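In spirit, evaluating a composite loss amounts to weighting each term's value and summing. A minimal pure-Python sketch of that semantics follows; the actual feedbax implementation operates on PyTrees of JAX arrays and returns a LossDict, and composite_loss here is a hypothetical name:

```python
def composite_loss(term_values, weights=None):
    """Weight each term value and sum.

    term_values: dict mapping term label -> scalar loss value.
    weights: dict with the same keys; by default all terms
    have equal weight, mirroring the docstring above.
    Returns (weighted term values, total).
    """
    if weights is None:
        weights = {k: 1.0 for k in term_values}
    weighted = {k: weights[k] * v for k, v in term_values.items()}
    return weighted, sum(weighted.values())
```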
__call__(states: State, trial_specs: AbstractTaskTrialSpec) -> LossDict

Evaluate, weight, and return all component terms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| states | State | Trajectories of system states for a set of trials. | required |
| trial_specs | AbstractTaskTrialSpec | Task specifications for the set of trials. | required |

feedbax.loss.EffectorPositionLoss (AbstractLoss)

Penalizes the effector's squared distance from the target position across the trial.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |
| discount_func | Callable[[int], Float[Array, '#time']] | Returns a trajectory with which to weight (discount) the loss values calculated for each time step of the trial. Defaults to a power-law curve that puts most of the weight on time steps near the end of the trial. |

Note

If the return value of discount_func gives non-zero weight to the position error during the fixation period of (say) a delayed reach task, then the target is typically specified as the fixation point during that period, and EffectorPositionLoss also acts as a fixation loss.

On the other hand, with certain kinds of goal error discounting (e.g. exponential, favouring errors near the end of the trial), the fixation loss may not be weighed into EffectorPositionLoss, and it may be appropriate to add EffectorFixationLoss to the composite loss. Even then, the same result could be achieved with a single instance of EffectorPositionLoss, by passing a discount that is the sum of (1) the goal error discount (say, non-zero only near the end of the trial) and (2) the hold signal (non-zero only during the fixation period), each scaled by the relative weight of the goal and fixation error losses.
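The single-discount alternative described in this note can be sketched in NumPy. The trial length, period boundaries, and weights below are hypothetical, chosen only to show the construction:

```python
import numpy as np

n_steps = 100

# Goal error discount: non-zero only near the end of the trial.
goal_discount = np.zeros(n_steps)
goal_discount[-10:] = 1.0

# Hold (fixation) signal: non-zero only during an early fixation period.
hold = np.zeros(n_steps)
hold[:30] = 1.0

# Relative weights of the goal and fixation error losses.
w_goal, w_fix = 1.0, 0.5

# A single discount that makes one EffectorPositionLoss act as
# both a goal error loss and a fixation loss.
combined_discount = w_goal * goal_discount + w_fix * hold
```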

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.EffectorStraightPathLoss (AbstractLoss)

Penalizes non-straight paths followed by the effector between initial and final position.

Calculates the length of the path followed, and normalizes it by the Euclidean (straight-line) distance between the initial and final positions.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |
| normalize_by | Literal['actual', 'goal'] | Controls whether to normalize by the distance between the initial position and the actual final position, or between the initial position and the task-specified goal position. |
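The normalization can be sketched as follows, assuming positions form a (time, space) array. This is a simplified reimplementation for illustration, and straight_path_ratio is a hypothetical name:

```python
import numpy as np

def straight_path_ratio(pos, goal=None):
    """Path length divided by straight-line distance.

    pos: (n_steps, n_dim) array of effector positions.
    goal: if given, normalize by the distance from the initial
    position to the goal ('goal' mode); otherwise by the distance
    to the actual final position ('actual' mode).
    Equals 1.0 for a perfectly straight path.
    """
    # Sum of segment lengths along the trajectory.
    path_length = np.linalg.norm(np.diff(pos, axis=0), axis=1).sum()
    end = pos[-1] if goal is None else goal
    return path_length / np.linalg.norm(end - pos[0])
```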

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.EffectorFixationLoss (AbstractLoss)

Penalizes the effector's squared distance from the fixation position.

Similar to EffectorPositionLoss, but only penalizes the position error during the part of the trial where trial_specs.inputs.hold is non-zero/True.
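The masking behaviour can be sketched in NumPy: squared distance from the fixation point, zeroed wherever the hold signal is off. A simplified illustration, not the actual implementation:

```python
import numpy as np

def fixation_loss(pos, fixation, hold):
    """Mean squared distance from fixation, counted only while hold is on.

    pos: (n_steps, n_dim) effector positions.
    fixation: (n_dim,) fixation position.
    hold: (n_steps,) boolean or 0/1 hold signal.
    """
    sq_err = ((pos - fixation) ** 2).sum(axis=1)
    # Position errors outside the hold period contribute nothing.
    return (hold * sq_err).mean()
```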

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.EffectorFinalVelocityLoss (AbstractLoss)

Penalizes the squared difference between the effector's velocity on the final time step and the goal velocity (typically zero).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.NetworkOutputLoss (AbstractLoss)

Penalizes the squared values of the network's outputs.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.NetworkActivityLoss (AbstractLoss)

Penalizes the squared values of the network's hidden activity.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| label | str | The label for the loss term. |

__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict

Inherited from feedbax.loss.AbstractLoss

feedbax.loss.power_discount (n_steps: int, discount_exp: int = 6) -> Array

A power-law vector that puts most of the weight on its later elements.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_steps | int | The number of time steps in the trajectory to be weighted. | required |
| discount_exp | int | The exponent of the power law. | 6 |
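One plausible implementation of such a power-law weighting is sketched below; the exact scaling and normalization used by feedbax may differ:

```python
import numpy as np

def power_discount(n_steps, discount_exp=6):
    """Power-law weights that increase toward the end of the
    trajectory, normalized so the final element is 1.

    Illustrative sketch; not feedbax's actual implementation.
    """
    t = np.arange(1, n_steps + 1)
    return (t / n_steps) ** discount_exp
```

With discount_exp=0 all time steps are weighted equally; larger exponents concentrate the weight on the final time steps.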

Abstract base classes

feedbax.loss.AbstractLoss (Module)

Abstract base class for loss functions.

Instances can be composed by addition and scalar multiplication.
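Composition by addition and scalar multiplication typically relies on Python operator overloading. A toy sketch of the pattern follows, with hypothetical names; feedbax's actual classes build a CompositeLoss rather than this simplified wrapper:

```python
class Loss:
    """Toy loss supporting `w * loss` and `loss_a + loss_b`."""

    def __init__(self, fn, weight=1.0):
        self.fn, self.weight = fn, weight

    def __call__(self, x):
        return self.weight * self.fn(x)

    def __mul__(self, w):
        # Scalar multiplication rescales the term's weight.
        return Loss(self.fn, self.weight * w)

    __rmul__ = __mul__

    def __add__(self, other):
        # Addition composes two terms into one that sums both.
        return Loss(lambda x: self(x) + other(x))
```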

term(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> Array
abstractmethod

Implement this to calculate a loss term.