Saving and loading
We often want to save the parameters of a trained model, so that we can use it again later without needing to re-train.
All Feedbax components, including (automatically, and for free) any that you might write, are PyTrees: they are represented as tree-structured data, which Equinox can save to a file.

Feedbax provides some functions to make this slightly easier. However, you can also use the Equinox functions `tree_serialise_leaves` and `tree_deserialise_leaves` directly, if you prefer a different scheme for saving and loading.
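For reference, here's a minimal sketch of the Equinox functions used directly, with a small `eqx.nn.MLP` standing in for a model (the filename and the MLP itself are ours, purely for illustration):

```python
import equinox as eqx
import jax

# Any PyTree of arrays works; a small MLP stands in for a model here.
model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1,
                   key=jax.random.PRNGKey(0))

# Write the values of all the leaves (arrays, etc.) to a file.
eqx.tree_serialise_leaves("model.eqx", model)

# Loading requires a "skeleton" with the same PyTree structure, whose
# leaves are then overwritten with the values read from the file.
skeleton = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1,
                      key=jax.random.PRNGKey(0))
model_loaded = eqx.tree_deserialise_leaves("model.eqx", skeleton)
```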
Here's an example of how to use the functions provided by Feedbax.
We'll start by writing a function that sets up the components we're going to want to save.
```python
import jax

from feedbax.task import SimpleReaches
from feedbax.xabdeef.losses import simple_reach_loss
from feedbax.xabdeef.models import point_mass_nn


# The leading asterisk forces all the arguments to be passed as keyword arguments
def setup(*, workspace, n_steps, dt, hidden_size, key):
    task = SimpleReaches(
        loss_func=simple_reach_loss(),
        workspace=workspace,
        n_steps=n_steps,
    )
    model = point_mass_nn(task, dt=dt, hidden_size=hidden_size, key=key)
    return task, model
```
Use this function to construct task and model objects. Let's keep all the arguments we pass to `setup` together in a dictionary called `hyperparameters`, since we'll want to save them along with the model.
```python
hyperparameters = dict(
    workspace=((-1., -1.),  # Workspace bounds ((x_min, y_min), (x_max, y_max))
               (1., 1.)),
    n_steps=100,      # Number of time steps per trial
    dt=0.05,          # Duration of a time step
    hidden_size=50,   # Number of units in the hidden layer of the controller
)

key_init, key_train, key_eval = jax.random.split(jax.random.PRNGKey(0), 3)

task, model = setup(**hyperparameters, key=key_init)
```
Now train the model to perform the task. We'll just do a short run of 500 batches, since in this case we're interested in whether a model can be successfully saved and reloaded, not whether it's fully converged on a solution.
```python
import optax

from feedbax.train import TaskTrainer


trainer = TaskTrainer(
    optimizer=optax.adam(learning_rate=1e-2),
)

model_trained, _ = trainer(
    task=task,
    model=model,
    n_batches=500,
    batch_size=250,
    where_train=lambda model: model.step.net,
    key=key_train,
)
```
Here's how we can save both the task and the trained model to a file.
```python
from feedbax import save, load

save_path = "example_save.eqx"

save(
    save_path,
    (task, model_trained),
    hyperparameters=hyperparameters,
)
```
Now, immediately load it again.
```python
task_loaded, model_loaded = load(save_path, setup)
```
How does this work?

The short version is that `setup` is a reproducible way to construct our model and our task, which is reused by `load` to reconstruct them in exactly the same way, before replacing their parameters with the saved (trained) ones from the file created by `save`.
The slightly longer version: the shape of the data we pass to `save` has to match what's constructed by the `setup` function. In this case, that's a tuple `task, model`. Between save time and load time, of course we need to store the file that was created by `save`. But we also store the source code for `setup`. At load time, we pass `setup` to `load`, along with the location of the saved file. First, `load` uses the `hyperparameters` stored in the file to call `setup` the same way we originally did, producing a skeleton `task, model`, which is then filled with the parameters from the `task, model` that had been passed to `save`.
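In other words, `load` behaves roughly like the following sketch. This is an illustration, not Feedbax's actual implementation; in particular, how the hyperparameters are stored alongside the serialised arrays, and which key the skeleton is built with, are assumptions here:

```python
import json

import equinox as eqx
import jax


def load_sketch(path, setup_func):
    with open(path, "rb") as f:
        # Assumption: the hyperparameters were written as a JSON header
        # line ahead of the serialised arrays.
        hyperparameters = json.loads(f.readline().decode())
        # Rebuild the skeleton PyTree exactly as at save time. The key is
        # arbitrary: every parameter it initialises is about to be replaced.
        skeleton = setup_func(**hyperparameters, key=jax.random.PRNGKey(0))
        # Overwrite the skeleton's leaves with the saved values.
        return eqx.tree_deserialise_leaves(f, skeleton)
```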
Don't use pickle

Python includes the module `pickle`, which can save and load entire Python objects without needing to specify, at the time of loading, how those objects were created. That is, it doesn't require us to hold on to a function like `setup` between save time and load time. This seems convenient, but it is not good practice in general:

- Upon loading, Python will automatically execute code found in a pickle file, in order to reconstruct the pickled objects. This is a security issue: if someone shares a pickled model with you, they (or an interloper) could insert harmful code into the pickle file, and you may not know it's there until you run it.
- Some of the components we use, such as `lambda` expressions, are not compatible with `pickle` (see the snippet below).
- You probably still have to keep track of how the objects in the pickle were created, for your research to be reproducible in detail. You could pickle `setup` as well, or even try to pickle all of your code, except you'd still run into the preceding issues. In the long run, the more explicit and organized solution is preferable, though it takes a little more work to start.
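For instance, here's the kind of failure you hit with lambdas (the exact error message varies across Python versions):

```python
import pickle

# `pickle` serialises a function as a reference to an importable name,
# which a lambda doesn't have, so this raises an error.
try:
    pickle.dumps(lambda model: model.step.net)
except Exception as e:
    print(f"{type(e).__name__}: {e}")
```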
See the Equinox documentation for a similar discussion of these limitations.
In general, `save` and `load` can be used to save the contents of any PyTree, so long as it's the same kind of PyTree returned by the setup function. That doesn't have to be `(task, model)`. Sometimes we might just care about storing the model, and not the task.
```python
def setup(*, ..., key):
    model = ...
    return model

model = setup(**hyperparameters, key=key_init)

# Train the model (omitted).
# ...

save(
    save_path,
    model,
    hyperparameters=hyperparameters,
)

model = load(save_path, setup)
```
Or, we could achieve the same thing by wrapping an existing setup function to make sure its return value matches what we pass to `save`.
```python
save_path_model_only = "example_save_model_only.eqx"

save(
    save_path_model_only,
    model_trained,
    hyperparameters=hyperparameters,
)


# Wrapper that turns a function that returns `task, model`
# into a function that returns `model`
def setup_model_only(**kwargs):
    _, model = setup(**kwargs)
    return model


model_loaded = load(
    save_path_model_only,
    setup_model_only,
)
```
Finally, we can compare the trained model before and after reloading, to show that the process of saving to a file preserves the model.
First, the original:
```python
from feedbax.plot import plot_reach_trajectories

states = task.eval(model_trained, key=key_eval)

_ = plot_reach_trajectories(
    states,
    trial_specs=task.validation_trials,
)
```
And the reloaded copy:
```python
states = task_loaded.eval(model_loaded, key=key_eval)

_ = plot_reach_trajectories(
    states,
    trial_specs=task.validation_trials,
)
```
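If a visual check isn't enough, we can also compare the two models' parameters directly. Here's a small sketch, assuming the array leaves are what we care about (non-array leaves such as functions are rebuilt fresh by `setup`, so we filter them out before comparing):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Keep only the array leaves of each model, then compare them pairwise.
params_trained = eqx.filter(model_trained, eqx.is_array)
params_loaded = eqx.filter(model_loaded, eqx.is_array)

all_equal = jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.array_equal, params_trained, params_loaded)
)
print(all_equal)  # Expect: True
```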