Loss functions
feedbax.loss.CompositeLoss(AbstractLoss)
Incorporates multiple simple loss terms and their relative weights.
__init__(terms: Mapping[str, AbstractLoss] | Sequence[AbstractLoss], weights: Optional[Mapping[str, float] | Sequence[float]] = None, label: str = '', user_labels: bool = True)
Note
During construction the user may pass dictionaries and/or sequences of AbstractLoss instances (terms) and weights.
Any CompositeLoss instances in terms are flattened, and their simple terms incorporated directly into the new composite loss, with the weights of those simple terms multiplied by the weight given in weights for their parent composite term.
If a composite term has a user-specified label, that label will be prepended to the labels of its component terms, on flattening. If the flattened terms still do not have unique labels, they will be suffixed with the lowest integer that makes them unique.
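The flattening rule in this note can be illustrated with a small standalone sketch. It does not use feedbax: the `Term` and `Composite` classes, the `flatten` and `make_unique` helpers, the `_` separator, and the exact suffix format are all hypothetical choices made for illustration, and the library's actual label format may differ.

```python
from dataclasses import dataclass


@dataclass
class Term:
    """Hypothetical stand-in for a simple loss term."""
    label: str


@dataclass
class Composite:
    """Hypothetical stand-in for a composite loss."""
    terms: list  # list of (child, weight) pairs; child is a Term or Composite
    label: str = ""  # empty string means no user-specified label


def flatten(composite, parent_weight=1.0, prefix=""):
    """Return (label, weight) pairs for every simple term, depth first."""
    flat = []
    for child, weight in composite.terms:
        w = parent_weight * weight  # child weight scaled by its parent's weight
        if isinstance(child, Composite):
            # Prepend the composite's label only if the user gave one.
            # The "_" separator is an assumption made for this sketch.
            child_prefix = prefix + (child.label + "_" if child.label else "")
            flat.extend(flatten(child, w, child_prefix))
        else:
            flat.append((prefix + child.label, w))
    return flat


def make_unique(pairs):
    """Suffix clashing labels with the lowest integer that makes them unique.

    Assumption: the first occurrence of a label keeps it unchanged.
    """
    unique = {}
    for label, w in pairs:
        new_label, i = label, 0
        while new_label in unique:
            i += 1
            new_label = f"{label}_{i}"
        unique[new_label] = w
    return unique


reach = Composite([(Term("position"), 1.0), (Term("velocity"), 0.5)], label="reach")
unlabelled = Composite([(Term("position"), 1.0)])
outer = Composite([(reach, 2.0), (unlabelled, 3.0), (Term("position"), 1.0)])

flat_weights = make_unique(flatten(outer))
```

In this sketch the terms of `reach` are scaled by 2.0 and prefixed with `reach_`, while the two remaining `position` terms clash and the second one receives an integer suffix.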
Parameters:
Name | Type | Description | Default |
---|---|---|---|
terms | Mapping[str, AbstractLoss] \| Sequence[AbstractLoss] | The sequence or mapping of loss terms to be included. | required |
weights | Optional[Mapping[str, float] \| Sequence[float]] | A float PyTree of the same structure as | None |
label | str | The label for the composite loss. | '' |
user_labels | bool | If | True |
__call__(states: State, trial_specs: AbstractTaskTrialSpec) -> LossDict
Evaluate, weight, and return all component terms.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
states | State | Trajectories of system states for a set of trials. | required |
trial_specs | AbstractTaskTrialSpec | Task specifications for the set of trials. | required |
feedbax.loss.EffectorPositionLoss(AbstractLoss)
Penalizes the effector's squared distance from the target position across the trial.
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
discount_func | Callable[[int], Float[Array, '#time']] | Returns a trajectory with which to weight (discount) the loss values calculated for each time step of the trial. Defaults to a power-law curve that puts most of the weight on time steps near the end of the trial. |
Note
If the return value of discount_func gives non-zero weight to the position error during the fixation period of (say) a delayed reach task, then typically the target will be specified as the fixation point during that period, and EffectorPositionLoss will also act as a fixation loss.
On the other hand, with certain kinds of goal error discounting (e.g. exponential, favouring errors near the end of the trial), the fixation-period error may receive little or no weight in EffectorPositionLoss, and it may be appropriate to add EffectorFixationLoss to the composite loss.
However, in that case the same result could still be achieved using a single instance of EffectorPositionLoss, by passing a discount that is the sum of the goal error discount (say, non-zero only near the end of the trial) and the hold signal (non-zero only during the fixation period), with the two scaled by the relative weights of the goal and fixation error losses.
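The single-instance alternative described in this note can be checked numerically. The sketch below is plain Python, not feedbax code: the trial length, the hold period, the shape of the goal discount, the weights `w_goal` and `w_fix`, and the per-step squared errors are made-up values, and it assumes the per-step losses are combined by a plain weighted sum.

```python
import math

n_steps = 8

# Goal-error discount: non-zero only near the end of the trial (assumed shape).
goal_discount = [0.0] * 6 + [0.5, 1.0]

# Hold signal: 1 during an assumed fixation period (first 3 steps), 0 afterwards.
hold = [1.0] * 3 + [0.0] * 5

# Relative weights of the goal and fixation error losses (hypothetical values).
w_goal, w_fix = 1.0, 0.5

# Combined discount that a single position-loss term could use.
combined_discount = [w_goal * g + w_fix * h for g, h in zip(goal_discount, hold)]

# Made-up squared position errors per time step, measured against a target
# that equals the fixation point while the hold signal is on.
sq_error = [0.04, 0.01, 0.0, 0.9, 0.5, 0.2, 0.05, 0.01]


def weighted_sum(weights, values):
    """Sum of per-step values, each scaled by its discount weight."""
    return sum(w * v for w, v in zip(weights, values))


# Two separate terms (goal + fixation) versus one term with the combined discount.
separate = w_goal * weighted_sum(goal_discount, sq_error) + w_fix * weighted_sum(hold, sq_error)
single = weighted_sum(combined_discount, sq_error)
```

Under these assumptions `separate` and `single` agree up to floating-point rounding, which is the equivalence the note describes.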
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.EffectorStraightPathLoss(AbstractLoss)
Penalizes non-straight paths followed by the effector between initial and final position.
Calculates the length of the paths followed, and normalizes by the Euclidean (straight-line) distance between the initial and final state.
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
normalize_by | Literal['actual', 'goal'] | Controls whether to normalize by the distance between the initial position & actual final position, or the initial position & task-specified goal position. |
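The path-length normalization and the two `normalize_by` options can be sketched in plain Python. This is an illustration of the idea only, not the library's implementation: the helper names `path_length` and `straight_path_ratio` are hypothetical, and whether feedbax returns this ratio directly or transforms it further is not stated here.

```python
import math


def path_length(positions):
    """Total length of the piecewise-linear path through the positions."""
    return sum(math.dist(a, b) for a, b in zip(positions, positions[1:]))


def straight_path_ratio(positions, goal, normalize_by="actual"):
    """Path length divided by a straight-line reference distance.

    'actual' uses the distance from the initial to the actual final position;
    'goal' uses the distance from the initial position to the goal position.
    """
    if normalize_by == "actual":
        end = positions[-1]
    elif normalize_by == "goal":
        end = goal
    else:
        raise ValueError(f"unknown normalize_by: {normalize_by!r}")
    return path_length(positions) / math.dist(positions[0], end)


# A made-up 2D trajectory that detours upward before coming back down.
path = [(0.0, 0.0), (3.0, 4.0), (6.0, 0.0)]
goal = (8.0, 0.0)

ratio_actual = straight_path_ratio(path, goal, "actual")
ratio_goal = straight_path_ratio(path, goal, "goal")
```

With `normalize_by="actual"` the ratio is 1 for a perfectly straight path and grows as the path bends, so it penalizes curvature regardless of where the effector ends up.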
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.EffectorFixationLoss(AbstractLoss)
Penalizes the effector's squared distance from the fixation position.
Similar to EffectorPositionLoss, but only penalizes the position error during the part of the trial where trial_specs.inputs.hold is non-zero/True.
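The masking by the hold signal can be sketched as follows. This is a plain-Python illustration under stated assumptions, not feedbax code: the function name `fixation_loss_sketch` is hypothetical, and summing over time steps (rather than averaging) is an arbitrary choice for the example.

```python
def fixation_loss_sketch(positions, fixation_point, hold):
    """Sum of squared distances from the fixation point over held time steps."""
    total = 0.0
    for pos, h in zip(positions, hold):
        if h:  # only time steps where the hold signal is non-zero/True count
            total += sum((p - f) ** 2 for p, f in zip(pos, fixation_point))
    return total


# Made-up trajectory: the effector drifts during the hold, then moves away.
positions = [(0.0, 0.0), (0.0, 1.0), (2.0, 2.0), (5.0, 5.0)]
hold = [True, True, False, False]

loss = fixation_loss_sketch(positions, (0.0, 0.0), hold)
```

Only the first two time steps contribute here; the large displacements after the hold period ends are ignored by this term.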
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.EffectorFinalVelocityLoss(AbstractLoss)
Penalizes the squared difference between the effector's velocity and the goal velocity (typically zero) on the final time step.
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.NetworkOutputLoss(AbstractLoss)
Penalizes the squared values of the network's outputs.
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.NetworkActivityLoss(AbstractLoss)
Penalizes the squared values of the network's hidden activity.
Attributes:
Name | Type | Description |
---|---|---|
label | str | The label for the loss term. |
__call__(states: PyTree, trial_specs: AbstractTaskTrialSpec) -> LossDict
Inherited from feedbax.loss.AbstractLoss
feedbax.loss.power_discount(n_steps: int, discount_exp: int = 6) -> Array
A power-law vector that puts most of the weight on its later elements.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_steps | int | The number of time steps in the trajectory to be weighted. | required |
discount_exp | int | The exponent of the power law. | 6 |
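To give a feel for what such a power-law vector looks like, here is a minimal plain-Python sketch. The formula used, `(t / n_steps) ** discount_exp` for `t = 1..n_steps`, is an assumption chosen to match the description (weights rise towards 1 at the final step); the function name `power_discount_sketch` is hypothetical and the library's exact expression may differ.

```python
def power_discount_sketch(n_steps, discount_exp=6):
    """Assumed form: weights (t / n_steps) ** discount_exp for t = 1..n_steps."""
    return [((t + 1) / n_steps) ** discount_exp for t in range(n_steps)]


# A small exponent makes the growth easy to read; larger exponents
# concentrate the weight even more strongly on the last few steps.
weights = power_discount_sketch(5, discount_exp=2)
steep = power_discount_sketch(5, discount_exp=6)
```

In this sketch the final element is always 1, and raising `discount_exp` shrinks the earlier weights relative to it.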