Pre-built models
Feedbax provides models, loss functions, and model-task pairings that can be constructed immediately, without having to assemble them from core components.
Model-task pairings
Example
All pairings are provided as TrainingContext objects and can be trained immediately. For example:
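```python
import jax.random as jr
from feedbax.xabdeef import point_mass_nn_simple_reaches

context = point_mass_nn_simple_reaches(key=jr.PRNGKey(0))
# n_batches, batch_size, and key are required by train; these values are illustrative.
model_trained, train_history = context.train(
    n_batches=1000, batch_size=250, key=jr.PRNGKey(1)
)
```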
feedbax.xabdeef.point_mass_nn_simple_reaches
point_mass_nn_simple_reaches(n_replicates: int = 1, n_steps: int = 100, dt: float = 0.05, mass: float = 1.0, workspace: Sequence[tuple[float, float]] = ((-1.0, -1.0), (1.0, 1.0)), encoding_size: Optional[int] = None, hidden_size: int = 50, hidden_type: type[eqx.Module] = eqx.nn.GRUCell, where_train: Callable = lambda model: model.step.net, feedback_delay_steps: int = 0, eval_grid_n: int = 1, eval_n_directions: int = 7, *, key: PRNGKeyArray) -> TrainingContext
A simple reach task paired with a neural network controlling a point mass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_replicates | int | The number of models to generate, each with a different random initialization. | 1 |
| n_steps | int | The number of time steps in each trial. | 100 |
| dt | float | The duration of each time step. | 0.05 |
| mass | float | The mass of the point mass. | 1.0 |
| workspace | Sequence[tuple[float, float]] | The bounds of the rectangular workspace. | ((-1.0, -1.0), (1.0, 1.0)) |
| encoding_size | Optional[int] | The number of units in the encoding layer of the network. Defaults to | None |
| hidden_size | int | The number of units in the hidden layer of the network. | 50 |
| hidden_type | type[eqx.Module] | The type of the hidden layer of the network. | eqx.nn.GRUCell |
| where_train | Callable | A function that takes a model and returns the part of the model that should be trained. | lambda model: model.step.net |
| feedback_delay_steps | int | The number of time steps by which sensory feedback is delayed. | 0 |
| eval_grid_n | int | The number of grid points per side for center-out reaches in the validation task. For example, a value of 2 gives a grid of 2 × 2 = 4 center-out reach sets. | 1 |
| eval_n_directions | int | The number of evenly spaced reach directions in each set of center-out reaches. | 7 |
| key | PRNGKeyArray | A random key used to initialize the model(s). | required |
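To make eval_grid_n and eval_n_directions concrete, here is a minimal pure-Python sketch of how a validation set of center-out reaches could be laid out over the default workspace. The helper center_out_targets, the placement of each grid center in the middle of an equal-sized cell, and the reach_length parameter are all assumptions for illustration; Feedbax's own validation task may place centers and targets differently.

```python
import math


def center_out_targets(workspace, eval_grid_n, eval_n_directions, reach_length):
    """Hypothetical helper: return one (center, targets) pair per grid cell.

    Assumes each of the eval_grid_n x eval_grid_n centers sits in the middle
    of an equal-sized cell of the rectangular workspace.
    """
    (x_min, y_min), (x_max, y_max) = workspace
    cell_w = (x_max - x_min) / eval_grid_n
    cell_h = (y_max - y_min) / eval_grid_n
    reach_sets = []
    for i in range(eval_grid_n):
        for j in range(eval_grid_n):
            center = (x_min + (i + 0.5) * cell_w, y_min + (j + 0.5) * cell_h)
            # Evenly spaced directions around the full circle.
            targets = [
                (
                    center[0] + reach_length * math.cos(2 * math.pi * k / eval_n_directions),
                    center[1] + reach_length * math.sin(2 * math.pi * k / eval_n_directions),
                )
                for k in range(eval_n_directions)
            ]
            reach_sets.append((center, targets))
    return reach_sets


# eval_grid_n=2 gives 2 x 2 = 4 center-out reach sets, each with 7 targets.
sets = center_out_targets(
    ((-1.0, -1.0), (1.0, 1.0)), eval_grid_n=2, eval_n_directions=7, reach_length=0.25
)
print(len(sets), [len(targets) for _, targets in sets])  # 4 [7, 7, 7, 7]
```

With the default eval_grid_n=1, the single center falls at the middle of the workspace.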
feedbax.xabdeef.TrainingContext (Module)
A model-task pairing that automatically constructs a TaskTrainer instance.
Attributes:
| Name | Type | Description |
|---|---|---|
| model | AbstractModel | The model. |
| task | AbstractTask | The task. |
| where_train | Callable | A function that takes the model and returns the parts of the model to be trained. |
| ensembled | bool | Whether |
train(*, n_batches: int, batch_size: int, learning_rate: float = 0.01, log_step: Optional[int] = None, optimizer_cls: Callable[..., optax.GradientTransformation] = optax.adam, key: PRNGKeyArray, **kwargs: Any) -> tuple[AbstractModel, TaskTrainerHistory]
Train the model on the task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_batches | int | The number of batches of trials to train on. | required |
| batch_size | int | The number of trials per batch. | required |
| learning_rate | float | The learning rate for the optimizer. | 0.01 |
| log_step | Optional[int] | The number of batches between logs of training progress. If | None |
| optimizer_cls | Callable[..., optax.GradientTransformation] | The Optax optimizer to use, given as a callable that returns an optax.GradientTransformation. | optax.adam |
| key | PRNGKeyArray | A PRNG key for initializing the model. | required |
| **kwargs | Any | Additional keyword arguments to pass to the | {} |
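The parameters of train map onto a conventional minibatch training loop. The sketch below is a toy, pure-Python illustration of that mapping, not TaskTrainer's implementation: the hypothetical train_toy fits a single weight w to data generated by y = 3x, draws n_batches batches of batch_size random inputs, applies a hand-written Adam update (standing in for optax.adam, with the same default hyperparameters b1=0.9, b2=0.999, eps=1e-8), and records the loss every log_step batches. A seeded random.Random plays the role of the PRNG key.

```python
import math
import random


def train_toy(*, n_batches, batch_size, learning_rate=0.01, log_step=None, seed=0):
    """Hypothetical toy trainer: fit w in y = w * x to data from y = 3x."""
    rng = random.Random(seed)  # plays the role of the PRNG `key`
    b1, b2, eps = 0.9, 0.999, 1e-8  # Adam hyperparameters
    w, m, v = 0.0, 0.0, 0.0
    history = []
    for batch in range(n_batches):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
        # Mean squared error over the batch and its gradient with respect to w.
        loss = sum((w * x - 3.0 * x) ** 2 for x in xs) / batch_size
        grad = sum(2.0 * (w * x - 3.0 * x) * x for x in xs) / batch_size
        # Adam update with bias correction (step count t starts at 1).
        t = batch + 1
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad * grad
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        w -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
        # Log progress every log_step batches; no logging if log_step is None.
        if log_step is not None and batch % log_step == 0:
            history.append((batch, loss))
    return w, history


w, history = train_toy(n_batches=2000, batch_size=32, learning_rate=0.05, log_step=500)
print(round(w, 3), [b for b, _ in history])
```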
Model only
feedbax.xabdeef.models.point_mass_nn
point_mass_nn(task: AbstractTask, n_steps: int = 100, dt: float = 0.05, mass: float = 1.0, encoding_size: Optional[int] = None, hidden_size: int = 50, hidden_type: type[eqx.Module] = eqx.nn.GRUCell, out_nonlinearity: Callable = identity_func, feedback_delay_steps: int = 0, feedback_noise_std: float = 0.025, motor_noise_std: float = 0.025, *, key: PRNGKeyArray)
A point mass controlled by forces output directly by a neural network.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| task | AbstractTask | The task the neural network will be trained to perform by controlling the point mass. | required |
| n_steps | int | The number of time steps in each trial. | 100 |
| dt | float | The duration of each time step. | 0.05 |
| mass | float | The mass of the point mass. | 1.0 |
| encoding_size | Optional[int] | The size of the neural network's encoding layer. If | None |
| hidden_size | int | The number of units in the network's hidden layer. | 50 |
| hidden_type | type[eqx.Module] | The type of the network's hidden layer. | eqx.nn.GRUCell |
| out_nonlinearity | Callable | The nonlinearity to apply to the network output. | identity_func |
| feedback_delay_steps | int | The number of time steps by which the sensory feedback the neural network receives about the point mass is delayed. | 0 |
| feedback_noise_std | float | The standard deviation of Gaussian noise added to the sensory feedback. | 0.025 |
| motor_noise_std | float | The standard deviation of Gaussian noise added to the forces generated by the neural network. | 0.025 |
| key | PRNGKeyArray | The random key to use for initializing the model. | required |
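The roles of mass, dt, feedback_delay_steps, feedback_noise_std, and motor_noise_std can be illustrated with a one-dimensional, pure-Python simulation. This is a sketch under stated assumptions, not Feedbax's implementation: the hypothetical simulate_point_mass takes a controller function standing in for the neural network, integrates with semi-implicit Euler (velocity updated first, then position), and models the feedback delay as a first-in, first-out buffer of past positions.

```python
import random
from collections import deque


def simulate_point_mass(controller, n_steps=100, dt=0.05, mass=1.0,
                        feedback_delay_steps=0, feedback_noise_std=0.025,
                        motor_noise_std=0.025, seed=0):
    """Hypothetical 1-D sketch of a point mass driven by a feedback controller.

    Returns the position after each step and the (delayed, noisy)
    observations the controller received.
    """
    rng = random.Random(seed)
    pos, vel = 0.0, 0.0
    # FIFO buffer: the controller reads the oldest stored position, which
    # lags the current position by feedback_delay_steps steps.
    buffer = deque([pos] * (feedback_delay_steps + 1), maxlen=feedback_delay_steps + 1)
    positions, observations = [], []
    for _ in range(n_steps):
        obs = buffer[0] + rng.gauss(0.0, feedback_noise_std)  # sensory noise
        observations.append(obs)
        force = controller(obs) + rng.gauss(0.0, motor_noise_std)  # motor noise
        vel += (force / mass) * dt  # a = F / m
        pos += vel * dt
        buffer.append(pos)
        positions.append(pos)
    return positions, observations


# A constant unit force with no noise: x_n = (F / m) * dt**2 * n * (n + 1) / 2.
positions, observations = simulate_point_mass(
    lambda obs: 1.0, feedback_delay_steps=3, feedback_noise_std=0.0, motor_noise_std=0.0
)
print(round(positions[-1], 6))  # 12.625
```

With feedback_delay_steps=3, the first four observations are the initial position, and each later observation is the position from three steps before the state the controller acts on.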
Loss functions
feedbax.xabdeef.losses.simple_reach_loss(loss_term_weights: Optional[Mapping[str, float]] = None, discount_exp: int = 6) -> CompositeLoss
A typical loss function for a simple reaching task.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss_term_weights | Optional[Mapping[str, float]] | Maps loss term names to term weights. If | None |
| discount_exp | int | The exponent of the power function used to discount the position error backward in time from the end of each trial. Larger values concentrate the penalty more heavily at the end of trials. If zero, all time steps are weighted equally. | 6 |
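The effect of discount_exp and loss_term_weights can be sketched in pure Python. The sketch assumes, as a hypothesis rather than a statement of Feedbax's code, that the discount weights have the form ((t + 1) / n_steps) ** discount_exp for each time step t from 0 to n_steps - 1, so the final step always has weight 1, and that the composite loss is a weighted sum of its terms. The helper names and the term names in the usage are hypothetical.

```python
def power_discount(n_steps, discount_exp):
    """Assumed form of the discount: weights rise to 1.0 at the final time step."""
    return [((t + 1) / n_steps) ** discount_exp for t in range(n_steps)]


def discounted_position_error(errors, discount_exp):
    """Time-averaged position error, weighted by the power-law discount."""
    weights = power_discount(len(errors), discount_exp)
    return sum(w * e for w, e in zip(weights, errors)) / len(errors)


def composite_loss(term_values, term_weights=None):
    """Weighted sum of named loss terms; unlisted terms get weight 1.0 (an assumption)."""
    term_weights = term_weights or {}
    return sum(term_weights.get(name, 1.0) * value for name, value in term_values.items())


weights = power_discount(100, discount_exp=6)
print(weights[49], weights[-1])  # 0.015625 1.0
# With discount_exp=0, every time step counts equally.
print(set(power_discount(100, discount_exp=0)))  # {1.0}
# Hypothetical term names, with the final-velocity term weighted twice as heavily.
print(composite_loss({"position": 0.2, "final_velocity": 0.5}, {"final_velocity": 2.0}))
```

With the default discount_exp=6, the midpoint of a 100-step trial carries only about 1.6% of the weight of the final step, which is what concentrates the penalty at the end of the reach.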