Pre-built models

Feedbax provides models, loss functions, and model-task pairings which can be constructed immediately, without needing to build them out of core components.

Model-task pairings

Example

All pairings are provided as TrainingContext objects, and can be trained immediately.

For example:

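import jax.random as jr
from feedbax.xabdeef import point_mass_nn_simple_reaches

key_init, key_train = jr.split(jr.PRNGKey(0))
context = point_mass_nn_simple_reaches(key=key_init)
# train() requires n_batches, batch_size, and key; these values are illustrative.
model_trained, train_history = context.train(
    n_batches=1000, batch_size=250, key=key_train
)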

feedbax.xabdeef.point_mass_nn_simple_reaches

point_mass_nn_simple_reaches(n_replicates: int = 1, n_steps: int = 100, dt: float = 0.05, mass: float = 1.0, workspace: Sequence[tuple[float, float]] = ((-1.0, -1.0), (1.0, 1.0)), encoding_size: Optional[int] = None, hidden_size: int = 50, hidden_type: type[eqx.Module] = eqx.nn.GRUCell, where_train: Callable = lambda model: model.step.net, feedback_delay_steps: int = 0, eval_grid_n: int = 1, eval_n_directions: int = 7, *, key: PRNGKeyArray) -> TrainingContext

A simple reach task paired with a neural network controlling a point mass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_replicates | int | The number of models to generate, with different random initializations. | 1 |
| n_steps | int | The number of time steps in each trial. | 100 |
| dt | float | The duration of each time step. | 0.05 |
| mass | float | The mass of the point mass. | 1.0 |
| workspace | Sequence[tuple[float, float]] | The bounds of the rectangular workspace. | ((-1.0, -1.0), (1.0, 1.0)) |
| encoding_size | Optional[int] | The number of units in the encoding layer of the network. Defaults to None (no encoding layer). | None |
| hidden_size | int | The number of units in the hidden layer of the network. | 50 |
| hidden_type | type[Module] | The type of the hidden layer of the network. | GRUCell |
| where_train | Callable | A function that takes a model and returns the part of the model that should be trained. | lambda model: model.step.net |
| feedback_delay_steps | int | The number of time steps by which sensory feedback is delayed. | 0 |
| eval_grid_n | int | The number of grid points along each side of the grid of center-out reach sets in the validation task. For example, a value of 2 gives a grid of 2x2=4 center-out reach sets. | 1 |
| eval_n_directions | int | The number of evenly-spread reach directions per set of center-out reaches. | 7 |
| key | PRNGKeyArray | A random key used to initialize the model(s). | required |
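To make the roles of workspace, eval_grid_n, and eval_n_directions concrete, here is a minimal sketch of how such a validation set could be laid out. It is not Feedbax's implementation. The helper name `center_out_reach_sets`, the choice to place grid points at the centers of equal cells, and the `reach_length` parameter are all assumptions made for illustration. The workspace is read as (lower-left corner, upper-right corner), which is the only reading consistent with the default value.

```python
import math
from typing import Sequence


def center_out_reach_sets(
    workspace: Sequence[tuple[float, float]] = ((-1.0, -1.0), (1.0, 1.0)),
    eval_grid_n: int = 1,
    eval_n_directions: int = 7,
    reach_length: float = 0.5,  # hypothetical; not part of the documented API
) -> list[list[tuple[tuple[float, float], tuple[float, float]]]]:
    """Return one list of (start, target) pairs per grid point."""
    (x_min, y_min), (x_max, y_max) = workspace
    # Assumption: grid points sit at the centers of an n-by-n tiling.
    cell_w = (x_max - x_min) / eval_grid_n
    cell_h = (y_max - y_min) / eval_grid_n
    reach_sets = []
    for i in range(eval_grid_n):
        for j in range(eval_grid_n):
            center = (x_min + (i + 0.5) * cell_w, y_min + (j + 0.5) * cell_h)
            reaches = []
            for k in range(eval_n_directions):
                # Directions are spread evenly around the full circle.
                angle = 2 * math.pi * k / eval_n_directions
                target = (
                    center[0] + reach_length * math.cos(angle),
                    center[1] + reach_length * math.sin(angle),
                )
                reaches.append((center, target))
            reach_sets.append(reaches)
    return reach_sets


reach_sets = center_out_reach_sets(eval_grid_n=2, eval_n_directions=7)
default_sets = center_out_reach_sets()
print(len(reach_sets), len(reach_sets[0]))
```

With eval_grid_n=2 and eval_n_directions=7, the sketch yields 4 reach sets of 7 reaches each. The defaults yield a single set centered in the workspace.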

feedbax.xabdeef.TrainingContext (Module)

A model-task pairing with automatic construction of a TaskTrainer instance.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | AbstractModel | The model. |
| task | AbstractTask | The task. |
| where_train | Callable | A function that takes the model and returns the parts of the model to be trained. |
| ensembled | bool | Whether model is an ensemble of models. |
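The where_train attribute is an ordinary function from a model to one of its parts. The sketch below illustrates that idea with plain dataclasses. The classes `Net`, `Step`, and `Model` are hypothetical stand-ins, not Feedbax types. Only the form of the selector, `lambda model: model.step.net`, comes from the documented default of point_mass_nn_simple_reaches.

```python
from dataclasses import dataclass, field


@dataclass
class Net:  # hypothetical stand-in for the neural network
    weights: list = field(default_factory=lambda: [0.1, 0.2])


@dataclass
class Step:  # hypothetical stand-in for a single model step
    net: Net = field(default_factory=Net)
    mass: float = 1.0


@dataclass
class Model:  # hypothetical stand-in for an AbstractModel
    step: Step = field(default_factory=Step)


# Same form as the documented default selector.
where_train = lambda model: model.step.net

model = Model()
trainable = where_train(model)
print(trainable)  # the selected part; `mass` lies outside it
```

Under this reading, anything the selector does not return (here, the mass) is left untouched by training.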

train(*, n_batches: int, batch_size: int, learning_rate: float = 0.01, log_step: Optional[int] = None, optimizer_cls: Callable[..., optax.GradientTransformation] = optax.adam, key: PRNGKeyArray, **kwargs: Any) -> tuple[AbstractModel, TaskTrainerHistory]

Train the model on the task.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_batches | int | The number of batches of trials to train on. | required |
| batch_size | int | The number of trials per batch. | required |
| learning_rate | float | The learning rate for the optimizer. | 0.01 |
| log_step | Optional[int] | The number of batches between logs of training progress. If None, 10 evenly-spaced logs will be made along the training run. | None |
| optimizer_cls | Callable[..., GradientTransformation] | The class of Optax optimizer to use. | adam |
| key | PRNGKeyArray | A PRNG key for initializing the model. | required |
| **kwargs | Any | Additional keyword arguments to pass to the TaskTrainer. | {} |
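The log_step default can be read as "divide the run into ten equal stretches". The following sketch shows one way that could be computed. The helpers `resolve_log_step` and `logged_batches` are hypothetical, and Feedbax may round or offset the logging points differently.

```python
from typing import Optional


def resolve_log_step(n_batches: int, log_step: Optional[int] = None) -> int:
    """Hypothetical helper: number of batches between progress logs."""
    if log_step is None:
        # Aim for 10 evenly spaced logs; never return 0 for short runs.
        return max(1, n_batches // 10)
    return log_step


def logged_batches(n_batches: int, log_step: Optional[int] = None) -> list[int]:
    """Indices of the batches at which progress would be logged."""
    step = resolve_log_step(n_batches, log_step)
    return [i for i in range(n_batches) if i % step == 0]


batches = logged_batches(1000)
print(batches)
```

For a run of 1000 batches this logs every 100th batch, starting at batch 0.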

Model only

feedbax.xabdeef.models.point_mass_nn

point_mass_nn(task: AbstractTask, n_steps: int = 100, dt: float = 0.05, mass: float = 1.0, encoding_size: Optional[int] = None, hidden_size: int = 50, hidden_type: type[eqx.Module] = eqx.nn.GRUCell, out_nonlinearity: Callable = identity_func, feedback_delay_steps: int = 0, feedback_noise_std: float = 0.025, motor_noise_std: float = 0.025, *, key: PRNGKeyArray)

A point mass controlled by forces output directly by a neural network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| task | AbstractTask | The task the neural network will be trained to perform by controlling the point mass. | required |
| n_steps | int | The number of time steps in each trial. | 100 |
| dt | float | The duration of each time step. | 0.05 |
| mass | float | The mass of the point mass. | 1.0 |
| encoding_size | Optional[int] | The size of the neural network's encoding layer. If None, no encoding layer is used. | None |
| hidden_size | int | The number of units in the network's hidden layer. | 50 |
| hidden_type | type[Module] | The network type of the hidden layer. | GRUCell |
| out_nonlinearity | Callable | The nonlinearity to apply to the network output. | identity_func |
| feedback_delay_steps | int | The number of time steps to delay sensory feedback provided to the neural network about the point mass. | 0 |
| feedback_noise_std | float | The standard deviation of Gaussian noise added to the sensory feedback. | 0.025 |
| motor_noise_std | float | The standard deviation of Gaussian noise added to the forces generated by the neural network. | 0.025 |
| key | PRNGKeyArray | The random key to use for initializing the model. | required |
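As a rough picture of what dt, mass, n_steps, and motor_noise_std control, the sketch below integrates a force-driven 2D point mass in plain Python. It is an illustration under stated assumptions, not Feedbax's simulation code: the function `simulate_point_mass`, the semi-implicit Euler update, and the `seed` argument are all hypothetical, and the neural network is replaced by a fixed sequence of force commands.

```python
import random


def simulate_point_mass(
    forces,  # sequence of (fx, fy) commands, one per time step
    dt: float = 0.05,
    mass: float = 1.0,
    motor_noise_std: float = 0.0,
    seed: int = 0,  # hypothetical; stands in for a PRNG key
):
    """Integrate a 2D point mass with semi-implicit Euler steps (an assumption)."""
    rng = random.Random(seed)
    pos = [0.0, 0.0]
    vel = [0.0, 0.0]
    trajectory = []
    for force in forces:
        for axis in range(2):
            # Gaussian noise is added to the commanded force.
            noisy_force = force[axis] + rng.gauss(0.0, motor_noise_std)
            vel[axis] += dt * noisy_force / mass  # a = F / m
            pos[axis] += dt * vel[axis]
        trajectory.append((tuple(pos), tuple(vel)))
    return trajectory


# Constant unit force along x for 100 steps of duration 0.05, without noise.
trajectory = simulate_point_mass([(1.0, 0.0)] * 100)
final_pos, final_vel = trajectory[-1]
print(final_pos, final_vel)
```

With a unit mass and no noise, the final x velocity is approximately 5.0 (100 steps of 0.05), which shows how n_steps and dt together set the trial duration.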

Loss functions

feedbax.xabdeef.losses.simple_reach_loss(loss_term_weights: Optional[Mapping[str, float]] = None, discount_exp: int = 6) -> CompositeLoss

A typical loss function for a simple reaching task.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_term_weights | Optional[Mapping[str, float]] | Maps loss term names to term weights. If None, a typical set of default weights is used. | None |
| discount_exp | int | The exponent of the power function used to discount the position error, back in time from the end of trials. Larger values lead to penalties that are more concentrated at the end of trials. If zero, all time steps are weighted equally. | 6 |
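The effect of discount_exp is easiest to see numerically. The sketch below builds per-time-step weights as a power of normalized time, so that the final step has weight 1 and earlier steps are progressively discounted. The function name `power_discount` and the exact normalization are assumptions for illustration; the library's own weighting may differ in detail.

```python
def power_discount(n_steps: int, discount_exp: int = 6) -> list[float]:
    """Hypothetical per-time-step weights for the position error."""
    if discount_exp == 0:
        # No discounting: all time steps are weighted equally.
        return [1.0] * n_steps
    # Assumption: weights rise as a power of normalized time and
    # reach 1 at the final step.
    return [((t + 1) / n_steps) ** discount_exp for t in range(n_steps)]


weights = power_discount(n_steps=100, discount_exp=6)
uniform = power_discount(n_steps=5, discount_exp=0)
print(weights[49], weights[-1])
```

Under this form with an exponent of 6, the weight halfway through a 100-step trial is 0.5 to the sixth power, about 0.016, so the position error is penalized almost entirely near the end of the trial.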