# Training

## feedbax.train.TaskTrainerHistory (Module)
A record of training history over a call to a `TaskTrainer` instance.
**Attributes:**

| Name | Type | Description |
|---|---|---|
| `loss` | `LossDict \| Array` | The training losses. |
| `learning_rate` | `Optional[Array]` | The optimizer's learning rate. |
| `model_trainables` | `Optional[AbstractModel]` | The model's trainable parameters. Non-trainable leaves appear as `None`. |
| `trial_specs` | `dict[int, AbstractTaskTrialSpec]` | The saved training trial specifications, keyed by batch number. |
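For orientation, here is a minimal sketch of inspecting a training history. It assumes `history` is the `TaskTrainerHistory` returned by a call to a `TaskTrainer` (documented below), and that `LossDict` aggregates its loss terms as `.total`; the individual keys of `history.loss` depend on the task's loss function.

```python
import matplotlib.pyplot as plt

# `history` is assumed to be the TaskTrainerHistory returned by a
# TaskTrainer call (see TaskTrainer.__call__ below).
total_loss = history.loss.total  # assumption: LossDict exposes a summed `.total`

# Plot the total training loss over training batches.
plt.loglog(total_loss)
plt.xlabel("Training batch")
plt.ylabel("Total loss")
plt.show()
```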
## feedbax.train.TaskTrainer (Module)

Manages resources needed to train models, given task specifications.
### __init__

```python
__init__(
    optimizer: optax.GradientTransformation,
    checkpointing: bool = True,
    chkpt_dir: str | Path = '/tmp/feedbax-checkpoints',
    enable_tensorboard: bool = False,
    tensorboard_logdir: str | Path = '/tmp/feedbax-tensorboard',
    model_update_funcs: Sequence[Callable] = (),
)
```
**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `optimizer` | `GradientTransformation` | The Optax optimizer to use for training. | *required* |
| `checkpointing` | `bool` | Whether to save model checkpoints during training. | `True` |
| `chkpt_dir` | `str \| Path` | The directory in which to save model checkpoints. | `'/tmp/feedbax-checkpoints'` |
| `enable_tensorboard` | `bool` | Whether to keep logs for Tensorboard. | `False` |
| `tensorboard_logdir` | `str \| Path` | The directory in which to save Tensorboard logs. | `'/tmp/feedbax-tensorboard'` |
| `model_update_funcs` | `Sequence[Callable]` | At the end of each training step/batch, each of these functions is passed 1) the model, and 2) the model states for all trials in the batch, and returns a model update. These can be used to implement state-dependent offline learning rules, such as batch-averaged Hebbian learning; see the sketch following this table. | `()` |
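As a minimal construction sketch: an ordinary Optax chain for the `optimizer` argument, and a stub showing the call signature that entries of `model_update_funcs` receive. The stub `batch_hebbian` is hypothetical, not part of feedbax, and the exact form of the update it should return is left open here.

```python
import optax
from feedbax.train import TaskTrainer

# A typical Optax optimizer: gradient clipping chained with Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

def batch_hebbian(model, states):
    """Hypothetical model update function.

    Called at the end of each training step with 1) the model and 2) the
    model states for all trials in the batch; returns a model update
    (e.g. a batch-averaged Hebbian change to some weights).
    """
    ...

trainer = TaskTrainer(
    optimizer=optimizer,
    checkpointing=True,                   # checkpoints go to chkpt_dir
    model_update_funcs=(batch_hebbian,),  # applied after every training step
)
```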
### __call__

```python
__call__(
    task: AbstractTask,
    model: AbstractModel[StateT],
    n_batches: int,
    batch_size: int,
    where_train: Callable[[AbstractModel[StateT]], Any],
    ensembled: bool = False,
    ensemble_random_trials: bool = True,
    log_step: int = 100,
    save_model_trainables: bool | Int[Array, "_"] = False,
    save_trial_specs: Optional[Int[Array, "_"]] = None,
    toggle_model_update_funcs: bool | PyTree[Int[Array, "_"]] = True,
    restore_checkpoint: bool = False,
    disable_tqdm: bool = False,
    batch_callbacks: Optional[Mapping[int, Sequence[Callable]]] = None,
    *,
    key: PRNGKeyArray,
)
```
Train a model on a fixed number of batches of task trials.
**Warning**

Model checkpointing saves only the model parameters, not the task or other
hyperparameters. That is, we assume the model and task passed to this method
are, aside from their trainable state, identical to those from the original
training run. This is typically the case when `restore_checkpoint=True` is
set immediately after a training run is interrupted, in order to resume it.

Trying to load a checkpoint as a model at a later time may fail. Use
`feedbax.save` and `feedbax.load` for longer-term storage.
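A hedged sketch of the longer-term storage path mentioned above. The exact signatures of `feedbax.save` and `feedbax.load` are not shown on this page, so treat the call pattern below (a path plus, for loading, a setup function that rebuilds the model skeleton) as an assumption to verify against their documentation.

```python
import feedbax

# Assumed call pattern -- verify against the feedbax.save / feedbax.load docs.
feedbax.save("model_trained.eqx", model_trained)

# `setup_model` is a hypothetical function that reconstructs an untrained
# model of the same structure, into which the saved parameters are loaded.
model_trained = feedbax.load("model_trained.eqx", setup_model)
```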
**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `task` | `AbstractTask` | The task to train the model on. | *required* |
| `model` | `AbstractModel[StateT]` | The model (or a vmapped batch/ensemble of models) to train. | *required* |
| `n_batches` | `int` | The number of batches of trials to train on. | *required* |
| `batch_size` | `int` | The number of trials in each batch. | *required* |
| `where_train` | `Callable[[AbstractModel[StateT]], Any]` | Selects the arrays in the model PyTree to be trained. | *required* |
| `ensembled` | `bool` | Should be set to `True` if `model` is a vmapped ensemble of models. | `False` |
| `ensemble_random_trials` | `bool` | If `True`, each model in an ensemble is trained on its own batches of random trials; if `False`, all models in the ensemble are trained on the same trials. | `True` |
| `log_step` | `int` | Interval at which to evaluate the model on the validation set, print losses to the console, log to Tensorboard (if enabled), and save checkpoints. | `100` |
| `save_model_trainables` | `bool \| Int[Array, "_"]` | Whether to return the entire history of the trainable leaves of the model (e.g. network weights) across training iterations, as part of the returned `TaskTrainerHistory`. May also be an array of batch numbers at which to save them. | `False` |
| `save_trial_specs` | `Optional[Int[Array, "_"]]` | A 1D array of batch numbers for which to keep the trial specifications, and return them as part of the training history. | `None` |
| `toggle_model_update_funcs` | `bool \| PyTree[Int[Array, "_"]]` | Whether to enable the model update functions. May also pass a PyTree with the same structure as the trainer's sequence of model update functions, to control them individually. | `True` |
| `restore_checkpoint` | `bool` | Whether to attempt to restore from the last saved checkpoint in the checkpoint directory. Typically, this option is set to continue a long training run immediately after it was interrupted. | `False` |
| `disable_tqdm` | `bool` | If `True`, do not display a tqdm progress bar during training. | `False` |
| `batch_callbacks` | `Optional[Mapping[int, Sequence[Callable]]]` | A mapping from batch number to a sequence of functions (taking no arguments) to be called immediately after the training step is performed for that batch. This can be used, say, for profiling parts of the training run. | `None` |
| `key` | `PRNGKeyArray` | The random key. | *required* |
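Putting the pieces together, a hedged end-to-end sketch of a training call. The task and model constructors are elided since they depend on your setup; the `where_train` path `model.step.net` (train only the network component) and the `(model, history)` return pair are assumptions about a typical feedbax model, not guarantees of this page.

```python
import jax.random as jr
import optax
from feedbax.train import TaskTrainer

key_init, key_train = jr.split(jr.PRNGKey(0))

task = ...   # some AbstractTask, e.g. a reaching task
model = ...  # some AbstractModel, built with key_init

trainer = TaskTrainer(optimizer=optax.adam(1e-3))

# Assumption: the trainer returns the trained model along with the
# TaskTrainerHistory documented at the top of this page.
model_trained, history = trainer(
    task=task,
    model=model,
    n_batches=1_000,
    batch_size=250,
    where_train=lambda model: model.step.net,  # assumed model structure
    log_step=200,
    key=key_train,
)

print(history.loss.total[-1])  # final total training loss
```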