Plotting

feedbax.plot.loss_history(train_history: TaskTrainerHistory, xscale: str = 'log', yscale: str = 'log', cmap_name: str = 'Set1') -> Tuple[Figure, Axes]

Line plot of loss terms and their total over a training run.

Note

Each term in train_history.loss is an array where the first dimension is the training iteration, with an optional second batch dimension, e.g. for model replicates.

Each term is plotted in a different color. If a batch dimension is present, multiple curves will be plotted for each term, all in the same color.

Training iteration labels start at 1, so that the first iteration remains visible when the x-axis is log-scaled (an iteration labelled 0 has no position on a log axis).
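The shape convention and the 1-based iteration labels described in this note can be sketched with NumPy alone. This is an illustration, not the library's implementation: the `loss_term` values are made up, and the real terms live in `train_history.loss`.

```python
import numpy as np

# Made-up loss term: 4 training iterations (rows) x 3 model replicates (columns).
loss_term = np.array([
    [1.00, 1.20, 0.90],
    [0.50, 0.55, 0.45],
    [0.20, 0.25, 0.22],
    [0.10, 0.12, 0.09],
])

n_iterations, n_replicates = loss_term.shape

# Label iterations 1..n rather than 0..n-1: log10(0) is -inf, so an
# iteration labelled 0 has no position on a log-scaled x-axis.
iterations = np.arange(1, n_iterations + 1)
log_positions = np.log10(iterations)

# One curve per replicate; all curves of a term would share one color.
curves = [loss_term[:, i] for i in range(n_replicates)]
```

Each entry of `curves` has one value per training iteration and would be drawn against `iterations`.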

Parameters:

Name Type Description Default
train_history TaskTrainerHistory

The training history object returned by a call to a TaskTrainer. The function will specifically access train_history.loss.

required
xscale str

The scale of the x-axis.

'log'
yscale str

The scale of the y-axis.

'log'
cmap_name str

The name of the Matplotlib colormap to use for line colors.

'Set1'

feedbax.plot.loss_mean_history(train_history: TaskTrainerHistory, xscale: str = 'log', yscale: str = 'log', cmap: str = 'Set1', errorbar: str | tuple[str, int] = 'sd') -> Tuple[Figure, Axes]

Line plot of the means and standard deviations of loss terms and their total, over a training run of a batch of multiple models.

To plot separate curves for each member of the batch, use loss_history.
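As a rough sketch of the summary this plot shows, the mean and standard deviation across the batch dimension can be computed with NumPy. The array values are made up, and whether the library uses the population or the sample standard deviation is not stated here, so `ddof=0` below is an assumption.

```python
import numpy as np

# Made-up loss term: 3 training iterations x 4 model replicates.
loss_term = np.array([
    [1.0, 2.0, 3.0, 4.0],
    [0.5, 1.5, 0.5, 1.5],
    [0.2, 0.2, 0.2, 0.2],
])

# Summarize across replicates (axis 1), leaving one value per iteration.
mean = loss_term.mean(axis=1)
sd = loss_term.std(axis=1, ddof=0)  # Assumption: population standard deviation.

# Bounds that could be drawn around the mean curve, e.g. as an error band.
lower, upper = mean - sd, mean + sd
```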

Parameters:

Name Type Description Default
train_history TaskTrainerHistory

The training history object returned by a call to a TaskTrainer. The function will specifically access train_history.loss.

required
xscale str

The scale of the x-axis.

'log'
yscale str

The scale of the y-axis.

'log'
cmap str

The name of the Matplotlib colormap to use for line colors.

'Set1'

feedbax.plot.effector_trajectories(states: SimpleFeedbackState | PyTree[Float[Array, 'trial time ...'] | Any], where_data: Optional[Callable] = None, step: int = 1, trial_specs: Optional[AbstractReachTrialSpec] = None, endpoints: Optional[Tuple[Float[Array, 'trial xy=2'], Float[Array, 'trial xy=2']]] = None, straight_guides: bool = False, workspace: Optional[Float[Array, 'bounds=2 xy=2']] = None, cmap_name: Optional[str] = None, colors: Optional[Sequence[str | Tuple[float, ...]]] = None, color: Optional[str | Tuple[float, ...]] = None, ms: int = 3, ms_source: int = 6, ms_target: int = 7, control_labels: Optional[Tuple[str, str, str]] = None, control_label_type: str = 'linear') -> Tuple[Figure, Axes]

Plot trajectories of position, velocity, and network output.

Parameters:

Name Type Description Default
states SimpleFeedbackState | PyTree[Float[Array, 'trial time ...'] | Any]

A model state or PyTree of arrays from which the variables to be plotted can be extracted.

required
where_data Optional[Callable]

If states is provided as an arbitrary PyTree of arrays, this function should be provided to extract the relevant arrays. It should take states and return a tuple of three arrays: position, velocity, and controller output/force.

None
step int

Plot every step-th trial. This is useful when states contains information about a very large set of trials, and we only want to plot a subset of them.

1
trial_specs Optional[AbstractReachTrialSpec]

The specifications for the trials being plotted. If supplied, this is used to plot markers at the initial and goal positions.

None
endpoints Optional[Tuple[Float[Array, 'trial xy=2'], Float[Array, 'trial xy=2']]]

The initial and goal positions for the trials being plotted. Overrides trial_specs.

None
straight_guides bool

If this is True and endpoints are provided, straight dashed lines will be drawn between the initial and goal positions.

False
workspace Optional[Float[Array, 'bounds=2 xy=2']]

The workspace bounds. If provided, the bounds are drawn as a rectangle.

None
cmap_name Optional[str]

The name of the Matplotlib colormap to use across trials.

None
colors Optional[Sequence[str | Tuple[float, ...]]]

A sequence of colors, one for each plotted trial. Overrides cmap_name.

None
color Optional[str | Tuple[float, ...]]

A single color to use for all trials. Overrides cmap_name but not colors.

None
ms int

Marker size for plots of states (trajectories).

3
ms_source int

Marker size for the initial position, if trial_specs/endpoints is provided.

6
ms_target int

Marker size for the goal position.

7
control_label_type str

If 'linear', labels the final (controller output/force) plot as showing Cartesian forces. If 'torques', labels the plot as showing the torques of a two-segment arm.

'linear'
control_labels Optional[Tuple[str, str, str]]

A tuple giving the labels for the title, x-axis, and y-axis of the final (controller output/force) plot. Overrides control_label_type.

None
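
To make `where_data` and `step` concrete, here is a minimal sketch using a plain dictionary as the PyTree of arrays. The dictionary layout and its key names are hypothetical; only the contract (a function that takes `states` and returns position, velocity, and controller output) comes from the parameter descriptions above. Starting the subset at trial 0 is also an assumption. The plotting call itself is left as a comment so that the snippet runs without feedbax installed.

```python
import numpy as np

n_trials, n_steps = 6, 50

# Hypothetical PyTree of arrays, each with leading (trial, time) dimensions.
states = {
    "effector": {
        "pos": np.zeros((n_trials, n_steps, 2)),
        "vel": np.zeros((n_trials, n_steps, 2)),
    },
    "control": np.zeros((n_trials, n_steps, 2)),
}

def where_data(states):
    """Return (position, velocity, controller output) from the hypothetical tree."""
    return (
        states["effector"]["pos"],
        states["effector"]["vel"],
        states["control"],
    )

# With step=2, every second trial is kept: trials 0, 2, and 4 here
# (assuming the subset starts at the first trial).
step = 2
pos, vel, control = (x[::step] for x in where_data(states))

# The corresponding call, following the signature documented above:
# fig, axs = effector_trajectories(states, where_data=where_data, step=step)
```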

feedbax.plot.reach_endpoint_dists(trial_specs: AbstractReachTrialSpec, s: int = 7, color: Optional[str] = None, **kwargs) -> Tuple[Figure, Sequence[Axes]]

Plot initial and goal positions along with their distributions.

Parameters:

Name Type Description Default
trial_specs AbstractReachTrialSpec

The specifications for the reach trials.

required
s int

Marker size for the initial and goal positions for all trials.

7
color Optional[str]

The color to use for all points. If None, black or white is automatically chosen based on the current Matplotlib theme.

None

feedbax.plot.activity_sample_units(activities: Float[Array, '*trial time unit'], n_samples: int, unit_includes: Optional[Sequence[int]] = None, cols: int = 2, cmap_name: str = 'tab10', figsize: tuple[float, float] = (6.4, 4.8), *, key: PRNGKeyArray) -> Tuple[Figure, Axes]

Plot activity over multiple trials for a random sample of network units.

The result is a figure with n_samples + len(unit_includes) subplots, arranged in cols columns.

When this function is called more than once in the course of an analysis, if the same key is passed and the network layer has the same number of units—that is, the last dimension of activities has the same size—then the same subset of units will be sampled.
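
The subplot count and the reproducible sampling described above can be sketched as follows. NumPy's seeded generator stands in for the JAX `key` here, and the helper name `sample_units` and the row calculation are illustrative assumptions rather than the library's code.

```python
import math
import numpy as np

def sample_units(n_units, n_samples, seed):
    """Pick n_samples distinct unit indices; a stand-in for key-based sampling."""
    rng = np.random.default_rng(seed)
    return rng.choice(n_units, size=n_samples, replace=False)

n_units, n_samples, cols = 100, 4, 2
unit_includes = [0, 7]

# Same seed and same layer size give the same sampled units.
first = sample_units(n_units, n_samples, seed=0)
second = sample_units(n_units, n_samples, seed=0)

# Total subplots, and the rows needed to fit them into `cols` columns.
n_subplots = n_samples + len(unit_includes)
rows = math.ceil(n_subplots / cols)
```

Reusing one key across calls therefore keeps the sampled units fixed, which makes it possible to compare the same units between conditions.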

Parameters:

Name Type Description Default
activities Float[Array, '*trial time unit']

The array of trial-by-trial activity over time for each unit in a network layer.

required
n_samples int

The number of units to sample from the layer. Along with unit_includes, this determines the number of subplots in the figure.

required
unit_includes Optional[Sequence[int]]

Indices of specific units to include in the plot, in addition to the n_samples randomly sampled units.

None
cols int

The number of columns in which to arrange the subplots.

2
cmap_name str

The name of the Matplotlib colormap to use. Each trial will be plotted in a different color.

'tab10'
figsize tuple[float, float]

The size of the figure.

(6.4, 4.8)
key PRNGKeyArray

A random key used to sample the units to plot.

required

feedbax.plot.activity_heatmap(activity: Float[Array, 'time unit'], cmap: str = 'viridis')

Plot activity of all units in a network layer over time, on a single trial.

Note

This is a helper for imshow, when the data is an array of neural network unit activities with shape (time, unit).

Example

When working with a SimpleFeedback model built with a SimpleStagedNetwork controller—for example, if we've constructed our model using point_mass_nn—we can plot the activity of the hidden layer of the network:

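from feedbax import tree_take
from feedbax.plot import activity_heatmap

# `task`, `model`, and `key_eval` are assumed to be defined already.
states = task.eval(model, key=key_eval)  # States for all validation trials.
states_trial0 = tree_take(states, 0)  # Keep only the first trial.
activity_heatmap(states_trial0.net.hidden)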

Parameters:

Name Type Description Default
activity Float[Array, 'time unit']

The array of activity over time for each unit in a network layer.

required
cmap str

The name of the Matplotlib colormap to use.

'viridis'

feedbax.plot.joint_pos_trajectory(xy: Float[Array, 'time links ndim=2'], cmap_name: str = 'viridis', length_unit: Optional[str] = None, ax: Optional[Axes] = None, add_root: bool = True, colorbar: bool = True, ms_trace: int = 6, lw_arm: int = 4, workspace: Optional[Float[Array, 'bounds=2 xy=2']] = None) -> tuple[Figure, Axes]

Plot the joint positions of an \(n\)-link arm over time.

Plots the full arm segments at the beginning, middle, and end of the trial, along with joint traces for all time steps.

Parameters:

Name Type Description Default
xy Float[Array, 'time links ndim=2']

The joint positions over time.

required
cmap_name str

The name of the Matplotlib colormap to use across time.

'viridis'
length_unit Optional[str]

The length unit to display on the axes. By default, axes are unlabeled.

None
ax Optional[Axes]

The Matplotlib axes to plot on. If None, a new figure and axes will be created.

None
add_root bool

Whether to add a root joint to xy; i.e. prepend the origin \((0,0)\) to the array of joint positions.

True
colorbar bool

If True, adds a colorbar to the figure.

True
ms_trace int

Marker size for the joint position traces over time.

6
lw_arm int

Line width of the arm segments.

4
workspace Optional[Float[Array, 'bounds=2 xy=2']]

The workspace bounds. If provided, the bounds are drawn as a rectangle.

None
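
The effect of `add_root` and the choice of time steps for the full arm drawings can be illustrated with NumPy. The joint positions below are made up, and taking the middle index as `n_steps // 2` is an assumption about how "middle" is chosen.

```python
import numpy as np

n_steps, n_links = 5, 2

# Made-up joint positions with shape (time, links, 2).
xy = np.ones((n_steps, n_links, 2))
xy[:, 1] *= 2.0  # Put the second joint further from the origin.

# add_root=True: prepend the origin (0, 0) as a root joint at every time step.
root = np.zeros((n_steps, 1, 2))
xy_with_root = np.concatenate([root, xy], axis=1)

# Time steps at which the full arm would be drawn: beginning, middle, end.
arm_steps = [0, n_steps // 2, n_steps - 1]
```

After prepending the root, the array has one more entry along the `links` dimension, so each arm segment can be drawn between consecutive joints starting from the origin.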