Why JAX?¤
If you are interested in Feedbax but unfamiliar with JAX—or new to Python—then keep reading, for an overview of some of the tools on which Feedbax is based.
JAX isn't a machine learning framework like PyTorch. It's a more general-purpose tool.
What does JAX provide?
- A NumPy-like API: Many of the things you can write in NumPy, you can also write in JAX—you just have to import jax.numpy as jnp instead of import numpy as np.
- Just-in-time (JIT) compilation: In many cases, this makes JAX much faster than NumPy.
- Automatic differentiation: We use this to get derivatives of functions—usually, to train models through gradient descent.
- Automatic vectorization: We can easily transform a function that works on single examples, to a function that processes entire batches of data.
- Parallelism: It's easy to split up a large model across multiple devices (e.g. GPUs).
Automatic differentiation and JIT compilation are features normally found working in the background in ML frameworks, but JAX lets you use them in explicit, arbitrary, powerful ways.
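For instance (a minimal sketch, not code from Feedbax itself), we can apply these transformations directly to ordinary Python functions:
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)           # automatic differentiation
fast_grad_loss = jax.jit(grad_loss)  # just-in-time compilation

fast_grad_loss(jnp.array([1.0, 2.0, 3.0]))  # Array([2., 4., 6.])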
That's why Feedbax is not just built on JAX, but also:
- Equinox, which allows us to define PyTorch-like modules, making it easier to organize our models;
- Optax, which provides optimizers (like Adam) which you'd normally find in ML frameworks;
- Diffrax, which provides numerical solvers for differential equations.
My favourite part about working with JAX is how nicely it plays with nested containers of data, or PyTrees.
PyTrees¤
Let's start with a list and a dict that contain similar values.
some_list = [1, 2, 3]
some_dict = {'a': 1, 'b': 2, 'c': 3}
In standard Python, a comprehension is a typical way of applying some computation to every value in a list or dict.
[x ** 2 for x in some_list]
{k: x ** 2 for k, x in some_dict.items()}
While list and dict comprehensions are similar, they're not interchangeable. If our data is a list we can use the first method and get a list in return. But as soon as we introduce some data that's stored in a dict, we need to change our code.
Conveniently, JAX provides a function tree_map
that behaves the same way for both lists and dicts.
from jax.tree_util import tree_map
tree_map(lambda x: x ** 2, some_list)
tree_map(lambda x: x ** 2, some_dict)
Python's built-in map
Python includes a built-in function map which is similar in principle to tree_map. For example, we can do list(map(lambda x: x**2, some_list)) to get the same result as tree_map(lambda x: x**2, some_list).
However:
- map doesn't necessarily return the same type of data structure passed to it. To square all the values in a dict, and return a dict like we did with tree_map, we'd have to do something like dict(zip(some_dict.keys(), map(lambda x: x**2, some_dict.values()))). That's much less readable than the JAX solution.
- map only works one level deep into a container. This isn't an issue with some_list, which is a flat list of numbers. But in many other cases the data we'll want to transform will be in complex, nested trees.
Even better, tree_map
works on nested containers.
import jax.numpy as jnp
some_data = [{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]
tree_map(lambda x: x ** 2, some_data)
How does this work? JAX treats both lists and dicts—and any nested structures of lists and dicts—as PyTrees.
A PyTree's leaves are the data it ultimately contains.
from jax.tree_util import tree_leaves, tree_structure
tree_leaves(some_data)
Those leaves are arranged in a tree with a certain structure.
tree_structure(some_data)
Note that the JAX array counts as a leaf, not as part of the tree structure. What does JAX treat as a leaf, and what does it treat as the structure?
By default:
- leaves—AKA leaf nodes—include NumPy and JAX arrays, as well as basic data types like int, float, str, and bool;
- internal nodes are lists, dicts, and tuples, which JAX recognizes as PyTrees themselves. When JAX encounters these, its default stance is that "nesting continues here"—so it looks inside the node for leaves, or even deeper layers of nodes.
Importantly, we can change what counts as a leaf. Many functions that operate on PyTrees can take an argument is_leaf.
tree_leaves(some_data, is_leaf=lambda x: isinstance(x, dict))
tree_structure(some_data, is_leaf=lambda x: isinstance(x, dict))
Here, we've told JAX to treat dicts as leaves, rather than as containers. Now, the dicts appear whole and unflattened in the list of leaves, and the PyTree structure reflects this.
To get the leaves and the tree structure in one call, use tree_flatten:
from jax.tree_util import tree_flatten
leaves, structure = tree_flatten(some_data)
Given both leaves and the structure, tree_unflatten builds a PyTree. Let's reconstruct the original some_data:
from jax.tree_util import tree_unflatten
tree_unflatten(structure, leaves)
Because JAX understands the structure of PyTrees, we can apply operations to multiple PyTrees when their structures match.
some_arrays = [
(jnp.array([1, 2]), jnp.array([3, 4])),
jnp.array([5, 6])
]
some_other_arrays = [
(jnp.array([7, 8]), jnp.array([3, 4])),
jnp.array([1, 1])
]
tree_structure(some_arrays) == tree_structure(some_other_arrays)
tree_map(
lambda x, y: x + y,
some_arrays,
some_other_arrays
)
Here, tree_map works "leafwise" to pick out the arguments to a function: the x values are the leaves from some_arrays, and the y values are the matching leaves from some_other_arrays.
In this example, the result will be different if we tell JAX to treat tuples as leaves. The first two JAX arrays in each PyTree are contained in a tuple, so the first x passed to the function will be that whole tuple, as will the first y. When we apply + to two tuples, we concatenate them.
tree_map(
lambda x, y: x + y,
some_arrays,
some_other_arrays,
is_leaf=lambda x: isinstance(x, tuple)
)
The array that's not inside a tuple gets added the same way it did before, because it still counts as a leaf—it's just that now, tuples also count as leaves.
A PyTree of your own¤
A PyTree is any kind of container that JAX knows how to flatten and unflatten. By default, this includes lists, dicts, and tuples.
When the default containers aren't enough for us, we can define our own types of containers, and tell JAX how to flatten and unflatten them. After that, JAX will treat them as PyTrees!
from jax.tree_util import register_pytree_node_class
@register_pytree_node_class
class TwoValues:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        return (self.a, self.b), None  # leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        print(aux_data)
        return cls(*leaves)
Here, the method tree_flatten tells JAX how to flatten a TwoValues object into its leaves, and tree_unflatten tells it how to construct a TwoValues given the leaves.
As its name suggests, the decorator register_pytree_node_class
registers our new PyTree type with JAX.
Now we can use TwoValues
as part of any PyTree:
tree = (TwoValues(12, 45), 3, {'a': TwoValues(4, 5)})
tree_leaves(tree)
tree_structure(tree)
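Since TwoValues is now registered as a PyTree node, tree_map will flatten it, transform its leaves, and rebuild it with tree_unflatten. (As written above, tree_unflatten will also print its aux_data, which is None here.) A quick check:
doubled = tree_map(lambda x: x * 2, tree)
# The TwoValues instances have been reconstructed with doubled leaves.
doubled[0].a, doubled[0].b, doubled[1], doubled[2]['a'].a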
Equinox¤
Most objects in Feedbax are derived from equinox.Module.
Equinox adds some useful tools to JAX.
In particular, equinox.Module
allows us to easily define classes that are PyTrees, and that combine model parameters with model computations.
import equinox as eqx
import jax
class SomeModel(eqx.Module):
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2
# Construct an example model.
model = SomeModel(3, jnp.array([1, 2, 3]))
In our class definition, the method __call__ tells Python how a SomeModel object should behave when we call it like a function:
model(2.5)
This is a nice way to define and execute our model computation.
Another convenient thing about Equinox Module is that it's a dataclass. In a normal Python class, to assign param1 and param2 as instance attributes we'd have to do this:
class SomeModel:
    def __init__(self, param1: int, param2: jax.Array):
        self.param1 = param1
        self.param2 = param2

    def __call__(self, x: float):
        return self.param1 + x * self.param2
When our class is a dataclass, it automatically defines a default __init__
method like the one above. We just have to define the list of parameters (that is, dataclass fields):
from dataclasses import dataclass
@dataclass
class SomeModel:
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2
Any class or subclass we define from eqx.Module
will automatically work this way, without needing to add the @dataclass
decorator.
Note
We can still add our own __init__
method to a dataclass if we need to do something fancier than just assigning values to fields.
In case only small modifications to __init__
are needed, it may be convenient to define __post_init__
instead.
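For example, here's a sketch (not a class from Feedbax) of an Equinox module with a custom __init__ that builds its own fields:
import jax.random as jr

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size: int, out_size: int, key):
        # Fields can be assigned freely inside __init__ (or __post_init__).
        wkey, _ = jr.split(key)
        self.weight = jr.normal(wkey, (out_size, in_size))
        self.bias = jnp.zeros(out_size)

    def __call__(self, x):
        return self.weight @ x + self.bias

layer = Linear(3, 2, key=jr.PRNGKey(0))
layer(jnp.ones(3))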
The best thing about Equinox modules is that they are PyTrees:
# Get a flattened list of model parameters.
tree_leaves(model)
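Because the model is a PyTree, the PyTree tools from earlier work on it directly. For example, we can map a function over every parameter and get back a new SomeModel:
# Zero out every parameter, returning a new model.
zeroed_model = tree_map(lambda x: x * 0, model)
tree_leaves(zeroed_model)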
It turns out this is very useful for structuring models, but that's beyond the scope of this example.
Similarity of Equinox and PyTorch modules
Equinox's Module is kind of like PyTorch's nn.Module. However, PyTorch modules:
- are not PyTrees, because PyTorch has no general, built-in concept of PyTrees;
- are not automatically dataclasses, and it can be kind of problematic to convert them;
- define the model computation in the forward method, rather than __call__. Technically though, PyTorch still has to define __call__ in the background to have its module objects behave like functions.
Vectorisation and vmap¤
The power of PyTrees goes much deeper than we've seen here. The core JAX transformations, such as jax.vmap and jax.grad, operate on functions whose inputs and outputs can be arbitrary PyTrees of arrays. In particular, jax.vmap transforms a function that works on a single example into one that works on an entire batch.
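For example, here's a sketch (not part of the original walkthrough) that batches the SomeModel instance from above over many inputs at once, with no explicit loop:
batch_of_inputs = jnp.linspace(0.0, 1.0, 10)

# Map the model over the leading axis of its input.
jax.vmap(model)(batch_of_inputs).shape  # (10, 3)

# Equinox's filter_vmap does the same, but is smarter about arguments
# that aren't arrays (for example, entire modules).
eqx.filter_vmap(model)(batch_of_inputs).shape  # (10, 3)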
Note
If you run into problems with jax.vmap, try using Equinox's filter_vmap as we've done above. It does the same thing, but a little more intelligently.
Functions and states¤
JAX plays best with pure functions. Let's see what that means.
Perhaps you are familiar with object-oriented programming, where classes define how objects possess and manipulate their internal states. For example, let's define a type of object that 1) possesses two attributes, and 2) when it's called, returns a result, but also internally updates one of its attributes.
class StatefulFoo:
    smee: int
    a: int

    def __init__(self, a: int):
        self.smee = 0
        self.a = a

    def __call__(self, x: int):
        if x > 3:
            self.smee = 2
        return self.a * x
a = 2
foo = StatefulFoo(a)
x = 1
print("\t\tx\tsmee")
for i in range(7):
    x = foo(x)
    print(f"Step {i}:\t\t{x}\t{foo.smee}")
Importantly, the internal state—the value of foo.smee—changes once a certain value is passed to foo. This is obvious in this case, since we're printing foo.smee on every step. But under different circumstances, we might not even know it had changed.
Seen as a function, the main thing that foo does is return a result. But it also has the side effect of altering foo.smee.
On the other hand, a pure function does not have side effects. Everything the function does is turn its input into its return value.
We can still do what we did with foo.smee, except that smee can no longer be hidden. It just needs to be part of the input and output of the function.
class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    def __call__(self, x: int, smee: int):
        if x > 3:
            smee = 2
        return self.a * x, smee
a = 2
foo = PureFoo(a)
smee = 0
x = 1
print("\t\tx\tsmee")
for i in range(7):
    x, smee = foo(x, smee)
    print(f"Step {i}:\t\t{x}\t{smee}")
We can go one step further, and bundle x and smee together into a single state object that the function takes and returns:
@dataclass
class Data:
    x: int
    smee: int


class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    # Takes Data, and returns Data.
    def __call__(self, data: Data) -> Data:
        if data.x > 3:
            smee = 2
        else:
            smee = data.smee
        return Data(self.a * data.x, smee)
a = 2
foo = PureFoo(a)
data = Data(x=1, smee=0)
print("\t\tx\tsmee")
for i in range(7):
    data = foo(data)
    print(f"Step {i}:\t\t{data.x}\t{data.smee}")
It turns out that as our programs grow complex, this style will work at least as well as the stateful style ever did—and without hiding anything.
In Feedbax, the relationship between a model and its state is like the relationship between PureFoo and Data in this example. A model does not possess state; it operates on it.
Similarly, we never change a state object by directly reassigning its values. For example, we would never do this to the Data object above:
data = Data(x=1, smee=0)
data.smee = 2
As we'll see shortly, this won't be a problem. We'll just need to define the alteration to data
as some function that takes data
as its input, and constructs the altered version as its output.
Equinox and pure functions¤
It might seem a little odd that we contrasted object oriented programming with purely functional programming, and then we kept defining our "pure function" as a class!
It's not really odd, though. What matters is that our classes behave like pure functions because of the way we define __call__. And classes do one very convenient thing for us: they let us keep fixed model parameters (like a) in the same place as a method that defines the model's computation.
This is essentially what eqx.Module
is for. And it forces us to code in a functional style. Watch what happens if we try to change one of the attributes of an Equinox module:
class Bar(eqx.Module):
    a: int

my_bar = Bar(a=3)
my_bar.a = 4
Things are no different if the object tries to change itself directly:
class Baz(eqx.Module):
    a: int

    def __call__(self, x: int):
        self.a = x

my_baz = Baz(a=3)
my_baz(4)
In other words, Equinox modules are immutable. Immutability goes hand in hand with pure functions, because it ensures that the internal state of our objects cannot be altered in the background.
So how do we "change" something that's immutable? We write a pure function that takes the old object and returns a new object with the change applied:
def foo_update(foo: PureFoo) -> PureFoo:
    a = foo.a + 1
    return PureFoo(a)
foo = PureFoo(a=2)
foo_new = foo_update(foo)
print(f"Old a: {foo.a}")
print(f"New a: {foo_new.a}")
This time we've just defined a plain old Python function with def, but we could also have used an Equinox module to define this function, if it had parameters of its own to remember.
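For example, here's a sketch (not something Feedbax defines) of an Equinox module that remembers a fixed increment as a parameter:
class FooUpdater(eqx.Module):
    increment: int

    def __call__(self, foo: PureFoo) -> PureFoo:
        # Return a new PureFoo; the old one is left untouched.
        return PureFoo(foo.a + self.increment)

update = FooUpdater(increment=1)
foo_new = update(foo)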
JAX arrays are immutable¤
Here's something we can do in NumPy:
import numpy as np
some_array = np.zeros((3, 3))
# Modify the array in-place.
some_array[0, 1:] = 5
some_array
The same thing doesn't work in JAX.
import jax.numpy as jnp
some_array = jnp.zeros((3, 3))
some_array[0, 1:] = 5
JAX arrays are immutable! That means we can't reach in and change an array in-place. We always have to perform some transformation that returns a new array.
JAX provides the .at[...].set(...) syntax for assigning a value to an index.
some_array = some_array.at[0, 1:].set(5)
some_array
The right hand side of the assignment can be seen as a function that takes some_array, and returns a new array object with the requested alteration.
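Besides .set, the .at property supports other out-of-place updates, such as .add and .multiply, which likewise return new arrays:
# Each of these returns a new array; the original is unchanged.
some_array = some_array.at[0, 1:].add(1)      # like += for that slice
some_array = some_array.at[:, 0].multiply(2)  # like *= for that slice
some_array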
In some cases it might appear that JAX performs in-place operations. For example, in NumPy we can do in-place addition like this:
another_array = np.zeros((2, 2))
another_array += 3
another_array
We can tell that the operation is in-place because another_array
has the same object ID as before:
id_before = id(another_array)
another_array += 10
id_after = id(another_array)
id_before == id_after
On the other hand, this is not the case in JAX:
# Now using jnp, not np!
another_array = jnp.zeros((2, 2))
id_before = id(another_array)
another_array += 10
id_after = id(another_array)
id_before == id_after
A different ID means a different Python object. In other words, JAX treats another_array += 10 like it would treat the purely functional another_array = another_array + 10, and not as an in-place update.
To be clear, you should write another_array = another_array + 10.
Immutability in Feedbax: performing surgery¤
In Feedbax, models and states can be really big PyTrees. Often we want to change just one part of them. But we don't want to keep writing huge functions that reconstruct the entire model, every time we want to replace just one piece.
Thankfully, Equinox provides a general-purpose function that can perform surgery.
Let's start with a pre-built model.
import jax
from feedbax.xabdeef import point_mass_nn_simple_reaches
context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))
model = context.model # Shorthand
This model has a point mass of mass \(1.0\) as its skeleton.
model.step.mechanics.plant.skeleton
As expected, if we try to directly alter the model to use a point mass of mass \(5.0\), an error is raised.
from feedbax.mechanics.skeleton import PointMass
# Try to replace the entire point mass
model.step.mechanics.plant.skeleton = PointMass(5.0)
# Or just try to change the mass
model.step.mechanics.plant.skeleton.mass = 5.0
Instead, we use the tree_at
function from Equinox. This is a function that takes a PyTree, and replaces a piece of it.
import equinox as eqx
model_heavy = eqx.tree_at(
lambda m: m.step.mechanics.plant.skeleton,
model,
PointMass(5.0)
)
The first argument to tree_at
is a "locator" function: when we pass it the model, it returns the part of the model we want to replace.
Lambda functions
Using Python's lambda
syntax lets us define the function inline. This isn't strictly necessary, but it's common practice in JAX when we need to define functions to pick out parts of PyTrees.
The second argument is just model, which is the model we want to alter.
The third argument is the replacement part.
Here, the model with the modifications is assigned to model_heavy, and we can still refer to the original model as model.
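The locator can point at any part of the tree. For example, assuming mass is an ordinary leaf of the PointMass (as the attribute access above suggests), we could replace just the mass rather than the whole skeleton:
# Replace only the mass, leaving the rest of the skeleton untouched.
model_heavier = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton.mass,
    model,
    5.0
)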
Random number generation¤
One other way that JAX differs from NumPy is how it handles the generation of random numbers.
In NumPy, random number generators are stateful. We can see this by calling one more than once.
np.random.random()
np.random.random()
The two numbers are different, but the input to the function was not: in both cases, we passed no arguments.
If a function's output differs when the input remains the same, it's not a pure function. In this case, the number changed because it was based on a state variable that changed in the background, like smee
did earlier.
These aren't actually random numbers; they're pseudo-random: they're the outputs of a deterministic function that varies wildly with its input. In NumPy, we can control where it starts from by setting the random seed:
seed = 1234
np.random.seed(seed)
np.random.random()
np.random.random()
np.random.seed(seed)
np.random.random()
np.random.random()
Whenever we set the seed, the random numbers start from the same point. This makes our subsequent calls reproducible, in principle. However, because the state of the random number generator continues to change in the background, our code may stop being reproducible if, at any point during the execution of our program, some other piece of code makes even a single call to the random number generator and changes its state.
Note
The situation is similar in PyTorch, and the above example can be repeated with torch.rand and torch.manual_seed.
JAX takes a totally functional and transparent approach to random numbers. Whenever we want to generate a random number, we have to pass a key.
import jax.random as jr
seed = 5678
key = jr.PRNGKey(seed)
jr.uniform(key)
If we call a random generator a second time with the same key, it returns the same result.
jr.uniform(key)
When we want to generate a new random number, we get a new key with split.
key1, key2 = jr.split(key)
print(jr.uniform(key1))
print(jr.uniform(key2))
This forces us to always be clear about the logic of how random numbers are generated. In JAX it's typical to see something like:
def generate_data(key):
    key_uniform, key_normal = jr.split(key)
    data_uniform = jr.uniform(key_uniform, (2, 4))
    data_normal = jr.normal(key_normal, (2, 4))
    return {'uniform': data_uniform, 'normal': data_normal}
Normally we pass only a single key to a function if the function needs to generate random numbers internally. The function determines how many keys it actually needs by splitting the one we send it.
key = jr.PRNGKey(seed)
generate_data(key)
# Use the same key for reproducible results.
generate_data(key)
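And whenever we want fresh data, we split off a new key first:
key, subkey = jr.split(key)
# Different values from before, but still fully determined by the keys.
generate_data(subkey)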
Compared to random number generation in NumPy and PyTorch, I find that JAX takes a small amount of extra effort—in most cases, writing at most 1 extra line of code per function, to split keys—but I end up feeling significantly more comfortable that my code's output is actually reproducible.