Plotting
feedbax.plot.loss_history(train_history: TaskTrainerHistory, xscale: str = 'log', yscale: str = 'log', cmap_name: str = 'Set1') -> Tuple[Figure, Axes]
Line plot of loss terms and their total over a training run.
Note
Each term in train_history.loss is an array whose first dimension is the training iteration, with an optional second batch dimension (e.g. for model replicates).
Each term is plotted in a different color. If a batch dimension is present, multiple curves are plotted for each term, all in the same color.
Training iteration labels start at 1, so that the first iteration is visible when the x-axis is log-scaled.
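The conventions in the note above can be sketched with plain Matplotlib. The helper plot_loss_terms below is hypothetical and is not part of feedbax. It takes a dict that maps term names to arrays instead of a TaskTrainerHistory. It is only meant to show three conventions: one color per term, a shared color across the batch dimension, and iteration labels that start at 1.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import numpy as np


def plot_loss_terms(losses, xscale="log", yscale="log", cmap_name="Set1"):
    """Hypothetical helper mirroring the conventions described for loss_history.

    `losses` maps each loss-term name to an array of shape (iteration,)
    or (iteration, batch).
    """
    fig, ax = plt.subplots()
    cmap = plt.get_cmap(cmap_name)
    for i, (name, values) in enumerate(losses.items()):
        values = np.asarray(values)
        if values.ndim == 1:
            values = values[:, None]  # No batch dimension: treat as a batch of one.
        # Label iterations from 1 so the first point is kept on a log-scaled x-axis.
        iterations = np.arange(1, values.shape[0] + 1)
        # One curve per batch member, all in the color assigned to this term.
        lines = ax.plot(iterations, values, color=cmap(i))
        lines[0].set_label(name)  # One legend entry per term.
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    ax.set_xlabel("Training iteration")
    ax.set_ylabel("Loss")
    ax.legend()
    return fig, ax


# Example: two terms, the second with a batch of 3 model replicates.
losses = {
    "position": np.linspace(1.0, 0.1, 50),
    "effort": np.linspace(0.5, 0.05, 50)[:, None] * np.array([1.0, 1.2, 0.8]),
}
fig, ax = plot_loss_terms(losses)
```

Plotting a 2D array in a single ax.plot call draws one line per column. All of those lines inherit the same color, which keeps each term visually grouped.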
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_history | TaskTrainerHistory | The training history object returned by a call to a | required |
xscale | str | The scale of the x-axis. | 'log' |
yscale | str | The scale of the y-axis. | 'log' |
cmap_name | str | The name of the Matplotlib colormap to use for line colors. | 'Set1' |
feedbax.plot.loss_mean_history(train_history: TaskTrainerHistory, xscale: str = 'log', yscale: str = 'log', cmap: str = 'Set1', errorbar: str | tuple[str, int] = 'sd') -> Tuple[Figure, Axes]
Line plot of the means and standard deviations of loss terms and their total, over a training run of a batch of multiple models.
To plot separate curves for each member of the batch, use loss_history.
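Here is a minimal sketch of a mean-and-spread plot. It assumes a dict of (iteration, batch) arrays stands in for the TaskTrainerHistory. The function name plot_loss_mean_sd is hypothetical. The sketch draws a ±1 SD band directly with Matplotlib's fill_between. It does not claim to reproduce how feedbax computes its error bars or how it handles errorbar values other than 'sd'.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import numpy as np


def plot_loss_mean_sd(losses, xscale="log", yscale="log", cmap_name="Set1"):
    """Hypothetical sketch: mean +/- 1 SD of each loss term across a batch of models.

    `losses` maps each term name to an array of shape (iteration, batch).
    """
    fig, ax = plt.subplots()
    cmap = plt.get_cmap(cmap_name)
    for i, (name, values) in enumerate(losses.items()):
        values = np.asarray(values)
        # 1-based iteration labels, mirroring the loss_history note (assumed here).
        iterations = np.arange(1, values.shape[0] + 1)
        mean = values.mean(axis=1)  # Mean over the batch axis.
        sd = values.std(axis=1)  # Population SD (ddof=0) over the batch axis.
        color = cmap(i)
        ax.plot(iterations, mean, color=color, label=name)
        # Shaded band from one SD below to one SD above the mean.
        ax.fill_between(iterations, mean - sd, mean + sd, color=color, alpha=0.3, linewidth=0)
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    ax.set_xlabel("Training iteration")
    ax.set_ylabel("Loss")
    ax.legend()
    return fig, ax


rng = np.random.default_rng(0)
losses = {
    "position": rng.uniform(0.1, 1.0, size=(50, 5)),  # 50 iterations, 5 models.
    "effort": rng.uniform(0.01, 0.1, size=(50, 5)),
}
fig, ax = plot_loss_mean_sd(losses)
```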
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_history | TaskTrainerHistory | The training history object returned by a call to a | required |
xscale | str | The scale of the x-axis. | 'log' |
yscale | str | The scale of the y-axis. | 'log' |
cmap | str | The name of the Matplotlib colormap to use for line colors. | 'Set1' |
feedbax.plot.effector_trajectories(states: SimpleFeedbackState | PyTree[Float[Array, 'trial time ...'] | Any], where_data: Optional[Callable] = None, step: int = 1, trial_specs: Optional[AbstractReachTrialSpec] = None, endpoints: Optional[Tuple[Float[Array, 'trial xy=2'], Float[Array, 'trial xy=2']]] = None, straight_guides: bool = False, workspace: Optional[Float[Array, 'bounds=2 xy=2']] = None, cmap_name: Optional[str] = None, colors: Optional[Sequence[str | Tuple[float, ...]]] = None, color: Optional[str | Tuple[float, ...]] = None, ms: int = 3, ms_source: int = 6, ms_target: int = 7, control_labels: Optional[Tuple[str, str, str]] = None, control_label_type: str = 'linear') -> Tuple[Figure, Axes]
Plot trajectories of position, velocity, and network output.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
states | SimpleFeedbackState \| PyTree[Float[Array, 'trial time ...'] \| Any] | A model state or PyTree of arrays from which the variables to be plotted can be extracted. | required |
where_data | Optional[Callable] | If | None |
step | int | Plot every | 1 |
trial_specs | Optional[AbstractReachTrialSpec] | The specifications for the trials being plotted. If supplied, this is used to plot markers at the initial and goal positions. | None |
endpoints | Optional[Tuple[Float[Array, 'trial xy=2'], Float[Array, 'trial xy=2']]] | The initial and goal positions for the trials being plotted. Overrides | None |
straight_guides | bool | If this is | False |
workspace | Optional[Float[Array, 'bounds=2 xy=2']] | The workspace bounds. If provided, the bounds are drawn as a rectangle. | None |
cmap_name | Optional[str] | The name of the Matplotlib colormap to use across trials. | None |
colors | Optional[Sequence[str \| Tuple[float, ...]]] | A sequence of colors, one for each plotted trial. Overrides | None |
color | Optional[str \| Tuple[float, ...]] | A single color to use for all trials. Overrides | None |
ms | int | Marker size for plots of states (trajectories). | 3 |
ms_source | int | Marker size for the initial position, if | 6 |
ms_target | int | Marker size for the goal position. | 7 |
control_label_type | str | If | 'linear' |
control_labels | Optional[Tuple[str, str, str]] | A tuple giving the labels for the title, x-axis, and y-axis of the final (controller output/force) plot. Overrides | None |
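The sketch below shows how endpoints, workspace, cmap_name, and the marker sizes might fit together. It is a hypothetical position-only panel, written with plain NumPy and Matplotlib. It makes two assumptions that the table above does not confirm. First, positions are already extracted as a (trial, time, 2) array. Second, the workspace rows are the lower-left and upper-right corners, [[x_min, y_min], [x_max, y_max]]. The velocity and control panels, where_data, step, and the remaining options are omitted.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle


def plot_position_panel(pos, endpoints=None, workspace=None, cmap_name="viridis",
                        ms=3, ms_source=6, ms_target=7):
    """Hypothetical sketch of a position-only trajectory panel.

    pos:       (trial, time, 2) effector positions.
    endpoints: optional (initial, goal) pair of (trial, 2) arrays.
    workspace: optional (2, 2) array, assumed to be [[x_min, y_min], [x_max, y_max]].
    """
    fig, ax = plt.subplots()
    n_trials = pos.shape[0]
    # Sample one color per trial, evenly spaced along the colormap.
    colors = plt.get_cmap(cmap_name)(np.linspace(0, 1, n_trials))
    for trial in range(n_trials):
        ax.plot(pos[trial, :, 0], pos[trial, :, 1], ".-", ms=ms, color=colors[trial])
    if endpoints is not None:
        init, goal = endpoints
        # Hollow markers: squares for initial positions, circles for goals.
        ax.plot(init[:, 0], init[:, 1], "s", ms=ms_source, mfc="none", mec="k")
        ax.plot(goal[:, 0], goal[:, 1], "o", ms=ms_target, mfc="none", mec="k")
    if workspace is not None:
        (x_min, y_min), (x_max, y_max) = workspace
        ax.add_patch(Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                               fill=False, linestyle="--", edgecolor="0.5"))
    ax.set_aspect("equal")
    return fig, ax


rng = np.random.default_rng(0)
init = rng.uniform(-0.5, 0.5, size=(3, 2))
goal = rng.uniform(-0.5, 0.5, size=(3, 2))
t = np.linspace(0, 1, 20)[None, :, None]  # (1, time, 1)
pos = init[:, None, :] + t * (goal - init)[:, None, :]  # Straight reaches, (trial, time, 2).
workspace = np.array([[-1.0, -1.0], [1.0, 1.0]])
fig, ax = plot_position_panel(pos, endpoints=(init, goal), workspace=workspace)
```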
feedbax.plot.reach_endpoint_dists(trial_specs: AbstractReachTrialSpec, s: int = 7, color: Optional[str] = None, **kwargs) -> Tuple[Figure, Sequence[Axes]]
Plot initial and goal positions along with their distributions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trial_specs | AbstractReachTrialSpec | The specifications for the reach trials. | required |
s | int | Marker size for the initial and goal positions for all trials. | 7 |
color | Optional[str] | The color to use for all points. If | None |
feedbax.plot.activity_sample_units(activities: Float[Array, '*trial time unit'], n_samples: int, unit_includes: Optional[Sequence[int]] = None, cols: int = 2, cmap_name: str = 'tab10', figsize: tuple[float, float] = (6.4, 4.8), *, key: PRNGKeyArray) -> Tuple[Figure, Axes]
Plot activity over multiple trials for a random sample of network units.
The result is a figure with n_samples + len(unit_includes) subplots, arranged in cols columns.
When this function is called more than once in the course of an analysis, the same subset of units will be sampled each time, provided that the same key is passed and the network layer has the same number of units (that is, the last dimension of activities has the same size).
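The reproducibility property above holds whenever the sampled indices depend only on key and the number of units. Below is a minimal sketch of that idea using jax.random.choice, together with the subplot grid arithmetic. The name sample_units is hypothetical. It is an assumption that feedbax samples in exactly this way, for example whether it avoids re-sampling units already listed in unit_includes.

```python
import math

import jax
import jax.numpy as jnp


def sample_units(key, activities, n_samples, unit_includes=None, cols=2):
    """Hypothetical sketch: pick unit indices to plot and the subplot grid shape.

    activities: array of shape (*trial, time, unit).
    """
    n_units = activities.shape[-1]  # The last dimension indexes units.
    # The draw depends only on `key` and `n_units`, so repeating the call with
    # the same key on a layer of the same width returns the same units.
    sampled = jax.random.choice(key, n_units, shape=(n_samples,), replace=False)
    includes = jnp.asarray([] if unit_includes is None else unit_includes, dtype=sampled.dtype)
    unit_idxs = jnp.concatenate([sampled, includes])
    # One subplot per unit, n_samples + len(unit_includes) in total, in `cols` columns.
    n_rows = math.ceil(unit_idxs.shape[0] / cols)
    return unit_idxs, (n_rows, cols)


key = jax.random.PRNGKey(0)
acts_a = jnp.zeros((4, 50, 100))  # (trial, time, unit)
acts_b = jnp.ones((8, 30, 100))  # Different trials and duration, same 100 units.
idxs_a, grid = sample_units(key, acts_a, n_samples=4, unit_includes=[0, 1])
idxs_b, _ = sample_units(key, acts_b, n_samples=4, unit_includes=[0, 1])
```

The two calls differ in trial count, duration, and values. They still select identical units, because only the key and the unit count enter the draw.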
Parameters:
Name | Type | Description | Default |
---|---|---|---|
activities | Float[Array, '*trial time unit'] | The array of trial-by-trial activity over time for each unit in a network layer. | required |
n_samples | int | The number of units to sample from the layer. Along with | required |
unit_includes | Optional[Sequence[int]] | Indices of specific units to include in the plot, in addition to the | None |
cols | int | The number of columns in which to arrange the subplots. | 2 |
cmap_name | str | The name of the Matplotlib colormap to use. Each trial will be plotted in a different color. | 'tab10' |
figsize | tuple[float, float] | The size of the figure. | (6.4, 4.8) |
key | PRNGKeyArray | A random key used to sample the units to plot. | required |
feedbax.plot.activity_heatmap(activity: Float[Array, 'time unit'], cmap: str = 'viridis')
Plot activity of all units in a network layer over time, on a single trial.
Note
This is a helper for imshow, for use when the data is an array of neural network unit activities with shape (time, unit).
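The following sketch gives a rough idea of what such an imshow helper does. heatmap_time_unit is a hypothetical function written directly against Matplotlib. Its orientation (units on the y-axis, time on the x-axis) and its colorbar are assumptions for the sketch, not documented behavior of activity_heatmap.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import numpy as np


def heatmap_time_unit(activity, cmap="viridis"):
    """Hypothetical sketch of an imshow helper for (time, unit) activity arrays."""
    activity = np.asarray(activity)
    fig, ax = plt.subplots()
    # Transpose so rows are units and columns are time steps (assumed orientation).
    # aspect="auto" stretches pixels to fill the axes whatever the array shape.
    im = ax.imshow(activity.T, aspect="auto", cmap=cmap, interpolation="none")
    ax.set_xlabel("Time step")
    ax.set_ylabel("Unit")
    fig.colorbar(im, ax=ax, label="Activity")
    return fig, ax


activity = np.random.default_rng(0).normal(size=(100, 32))  # 100 time steps, 32 units.
fig, ax = heatmap_time_unit(activity)
```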
Example
When working with a SimpleFeedback model built with a SimpleStagedNetwork controller (for example, if we've constructed our model using point_mass_nn), we can plot the activity of the hidden layer of the network:
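```python
from feedbax import tree_take
from feedbax.plot import activity_heatmap

# Assumes `task`, `model`, and `key_eval` were defined earlier in the analysis.
states = task.eval(model, key=key_eval)  # States for all validation trials.
states_trial0 = tree_take(states, 0)  # Select the first trial.
activity_heatmap(states_trial0.net.hidden)
```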
Parameters:
Name | Type | Description | Default |
---|---|---|---|
activity | Float[Array, 'time unit'] | The array of activity over time for each unit in a network layer. | required |
cmap | str | The name of the Matplotlib colormap to use. | 'viridis' |
feedbax.plot.joint_pos_trajectory(xy: Float[Array, 'time links ndim=2'], cmap_name: str = 'viridis', length_unit: Optional[str] = None, ax: Optional[Axes] = None, add_root: bool = True, colorbar: bool = True, ms_trace: int = 6, lw_arm: int = 4, workspace: Optional[Float[Array, 'bounds=2 xy=2']] = None) -> tuple[Figure, Axes]
Plot joint position for an \(n\)-link arm over time.
Plots the full arm segments at the beginning, middle, and end of the trial, along with joint traces for all time steps.
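The snapshot-plus-traces layout described above can be sketched as follows. plot_arm_snapshots is hypothetical. It assumes the root joint sits at the origin when add_root is true, and it omits the colorbar, length_unit labels, ax reuse, and workspace options.

```python
import matplotlib

matplotlib.use("Agg")  # Headless backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import numpy as np


def plot_arm_snapshots(xy, cmap_name="viridis", add_root=True, ms_trace=6, lw_arm=4):
    """Hypothetical sketch. xy: (time, links, 2) joint positions."""
    xy = np.asarray(xy, dtype=float)
    if add_root:
        # Assumption: the arm's root joint is fixed at the origin.
        root = np.zeros((xy.shape[0], 1, 2))
        xy = np.concatenate([root, xy], axis=1)
    n_steps, n_joints, _ = xy.shape
    cmap = plt.get_cmap(cmap_name)
    fig, ax = plt.subplots()
    # Joint traces: every joint at every time step, colored by normalized time.
    # scatter's `s` is an area in points**2, so square the marker size.
    t_frac = np.linspace(0, 1, n_steps)
    for j in range(n_joints):
        ax.scatter(xy[:, j, 0], xy[:, j, 1], s=ms_trace ** 2, c=t_frac, cmap=cmap, zorder=1)
    # Full arm (all segments) at the beginning, middle, and end of the trial.
    for t in (0, n_steps // 2, n_steps - 1):
        ax.plot(xy[t, :, 0], xy[t, :, 1], "-o", lw=lw_arm,
                color=cmap(t / max(n_steps - 1, 1)), zorder=2)
    ax.set_aspect("equal")
    return fig, ax


# A two-link arm sweeping through a quarter turn over 10 time steps.
angle = np.linspace(0, np.pi / 2, 10)
elbow = 0.3 * np.stack([np.cos(angle), np.sin(angle)], axis=-1)  # (time, 2)
hand = elbow + 0.3 * np.stack([np.cos(2 * angle), np.sin(2 * angle)], axis=-1)
xy = np.stack([elbow, hand], axis=1)  # (time, links=2, 2)
fig, ax = plot_arm_snapshots(xy)
```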
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xy | Float[Array, 'time links ndim=2'] | The joint positions over time. | required |
cmap_name | str | The name of the Matplotlib colormap to use across time. | 'viridis' |
length_unit | Optional[str] | The length unit to display on the axes. By default, axes are unlabeled. | None |
ax | Optional[Axes] | The Matplotlib axes to plot on. If | None |
add_root | bool | Whether to add a root joint to | True |
colorbar | bool | If | True |
ms_trace | int | Marker size for the joint position traces over time. | 6 |
lw_arm | int | Line width of the arm segments. | 4 |
workspace | Optional[Float[Array, 'bounds=2 xy=2']] | The workspace bounds. If provided, the bounds are drawn as a rectangle. | None |