Training

feedbax.train.TaskTrainerHistory (Module)

A record of training history over a call to a TaskTrainer instance.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `loss` | `LossDict \| Array` | The training losses. |
| `learning_rate` | `Optional[Array]` | The optimizer's learning rate. |
| `model_trainables` | `Optional[AbstractModel]` | The model's trainable parameters. Non-trainable leaves appear as `None`. |
| `trial_specs` | `dict[int, AbstractTaskTrialSpec]` | The training trial specifications. |
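As a quick illustration, here is a minimal sketch of inspecting a history object returned by a call to a `TaskTrainer` (documented below). Treating `history.loss` as exposing a `total` aggregate of its loss terms is an assumption for illustration, not something this page specifies:

```python
import matplotlib.pyplot as plt

# `history` is the TaskTrainerHistory returned by TaskTrainer.__call__.
# ASSUMPTION: the LossDict aggregates its terms under a `total` entry;
# if not, plot the individual entries of `history.loss` instead.
plt.plot(history.loss.total)
plt.xlabel("Training batch")
plt.ylabel("Total loss")
plt.show()
```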

feedbax.train.TaskTrainer (Module)

Manages resources needed to train models, given task specifications.

__init__(
    optimizer: optax.GradientTransformation,
    checkpointing: bool = True,
    chkpt_dir: str | Path = '/tmp/feedbax-checkpoints',
    enable_tensorboard: bool = False,
    tensorboard_logdir: str | Path = '/tmp/feedbax-tensorboard',
    model_update_funcs: Sequence[Callable] = (),
)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizer` | `GradientTransformation` | The Optax optimizer to use for training. | *required* |
| `checkpointing` | `bool` | Whether to save model checkpoints during training. | `True` |
| `chkpt_dir` | `str \| Path` | The directory in which to save model checkpoints. | `'/tmp/feedbax-checkpoints'` |
| `enable_tensorboard` | `bool` | Whether to keep logs for TensorBoard. | `False` |
| `tensorboard_logdir` | `str \| Path` | The directory in which to save TensorBoard logs. | `'/tmp/feedbax-tensorboard'` |
| `model_update_funcs` | `Sequence[Callable]` | At the end of each training step/batch, each of these functions is passed (1) the model and (2) the model states for all trials in the batch, and returns a model update. These can be used to implement state-dependent offline learning rules, such as batch-averaged Hebbian learning; see the sketch after this table. | `()` |
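To make the `model_update_funcs` hook concrete, here is a sketch of a batch-averaged Hebbian-style update. The callback signature (model, batch of states) follows the description above, but the state field path (`states.net.hidden`) and the form of the returned update are assumptions for illustration, not the documented feedbax interface:

```python
import jax.numpy as jnp
import optax

from feedbax.train import TaskTrainer

def hebbian_update(model, states):
    # ASSUMPTION: `states.net.hidden` holds unit activities with shape
    # (trials, timesteps, units); this path is illustrative only.
    h = states.net.hidden
    eta = 1e-4  # Hebbian learning rate
    # Average the outer product of unit activity with itself over trials
    # and time steps: a simple batch-averaged Hebbian term.
    dw = eta * jnp.einsum("bti,btj->ij", h, h) / (h.shape[0] * h.shape[1])
    # ASSUMPTION: the update is returned as an array (or PyTree) that the
    # trainer applies to the matching leaf of the model.
    return dw

trainer = TaskTrainer(
    optimizer=optax.adam(1e-3),
    model_update_funcs=(hebbian_update,),
)
```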
__call__(
    task: AbstractTask,
    model: AbstractModel[StateT],
    n_batches: int,
    batch_size: int,
    where_train: Callable[[AbstractModel[StateT]], Any],
    ensembled: bool = False,
    ensemble_random_trials: bool = True,
    log_step: int = 100,
    save_model_trainables: bool | Int[Array, "_"] = False,
    save_trial_specs: Optional[Int[Array, "_"]] = None,
    toggle_model_update_funcs: bool | PyTree[Int[Array, "_"]] = True,
    restore_checkpoint: bool = False,
    disable_tqdm: bool = False,
    batch_callbacks: Optional[Mapping[int, Sequence[Callable]]] = None,
    *,
    key: PRNGKeyArray,
)

Train a model on a fixed number of batches of task trials.

Warning

Model checkpointing saves only the model parameters, not the task or other hyperparameters. That is, we assume the model and task passed to this method are, aside from their trainable state, identical to those from the original training run. This is typically the case when `restore_checkpoint=True` is passed immediately after a training run is interrupted, in order to resume it.

Trying to load a checkpoint as a model at a later time may fail. Use `feedbax.save` and `feedbax.load` for longer-term storage.
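Concretely, resuming after an interruption amounts to repeating the original call with `restore_checkpoint=True`. A sketch, assuming `trainer`, `task`, `model`, and the other arguments are exactly those of the interrupted run, and that the call returns the trained model together with a `TaskTrainerHistory`:

```python
import jax.random as jr

# Same arguments as the interrupted run; only `restore_checkpoint` changes.
model_trained, history = trainer(
    task=task,
    model=model,
    n_batches=10_000,                   # unchanged from the original run
    batch_size=250,                     # unchanged from the original run
    where_train=lambda m: m.step.net,   # ASSUMPTION: illustrative PyTree path
    restore_checkpoint=True,            # pick up from the last checkpoint
    key=jr.PRNGKey(0),
)
```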

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `task` | `AbstractTask` | The task to train the model on. | *required* |
| `model` | `AbstractModel[StateT]` | The model to train, or a vmapped batch/ensemble of models. | *required* |
| `n_batches` | `int` | The number of batches of trials to train on. | *required* |
| `batch_size` | `int` | The number of trials in each batch. | *required* |
| `where_train` | `Callable[[AbstractModel[StateT]], Any]` | Selects the arrays from the model PyTree to be trained. | *required* |
| `ensembled` | `bool` | Set to `True` if `model` is a vmapped ensemble of models to be trained in parallel. | `False` |
| `ensemble_random_trials` | `bool` | If `False`, every model in an ensemble is trained on the same batches of trials; otherwise, a distinct batch is generated for each model. Has no effect if `ensembled` is `False`. | `True` |
| `log_step` | `int` | Interval, in batches, at which to evaluate the model on the validation set, print losses to the console, log to TensorBoard (if enabled), and save checkpoints. | `100` |
| `save_model_trainables` | `bool \| Int[Array, "_"]` | Whether to return the entire history of the trainable leaves of the model (e.g. network weights) across training iterations, as part of the `TaskTrainerHistory` object. May also be a 1D array of batch numbers on which to keep history. | `False` |
| `save_trial_specs` | `Optional[Int[Array, "_"]]` | A 1D array of batch numbers for which to keep trial specifications and return them as part of the training history. | `None` |
| `toggle_model_update_funcs` | `bool \| PyTree[Int[Array, "_"]]` | Whether to enable the model update functions. May also be a PyTree with the same structure as the `TaskTrainer`'s `model_update_funcs` attribute, where each leaf is a 1D array of batch numbers on which to enable the respective function. Ignored if `model_update_funcs` is empty. | `True` |
| `restore_checkpoint` | `bool` | Whether to attempt to restore from the last saved checkpoint in the checkpoint directory. Typically used to resume a long training run immediately after an interruption. | `False` |
| `disable_tqdm` | `bool` | If `True`, tqdm progress bars are disabled. | `False` |
| `batch_callbacks` | `Optional[Mapping[int, Sequence[Callable]]]` | A mapping from batch number to a sequence of zero-argument functions to be called immediately after the training step for that batch. This can be used, for example, to profile parts of the training run. | `None` |
| `key` | `PRNGKeyArray` | The random key. | *required* |
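Putting the pieces together, a minimal end-to-end sketch of a training run. The task and model constructors (`make_task`, `make_model`) and the `where_train` path are illustrative stand-ins; the `TaskTrainer` construction and call follow the signatures above, assuming the call returns the trained model together with a `TaskTrainerHistory`:

```python
import jax.random as jr
import optax

from feedbax.train import TaskTrainer

# HYPOTHETICAL constructors: substitute your own AbstractTask and
# AbstractModel instances here.
task = make_task()
model = make_model(key=jr.PRNGKey(1))

trainer = TaskTrainer(optimizer=optax.adam(learning_rate=1e-3))

model_trained, history = trainer(
    task=task,
    model=model,
    n_batches=2_000,
    batch_size=250,
    # Train only the network's leaves; `step.net` is an ASSUMED path
    # into the model PyTree, for illustration.
    where_train=lambda model: model.step.net,
    log_step=200,
    key=jr.PRNGKey(0),
)
```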