
Why JAX?

If you're interested in Feedbax but unfamiliar with JAX, or new to Python, then keep reading for an overview of some of the tools Feedbax is built on.

JAX isn't a machine learning framework like PyTorch. It's a more general-purpose library for numerical computing with arrays.

What does JAX provide?
  • A NumPy-like API: Many of the things you can write in NumPy, you can also write in JAX—you just have to import jax.numpy as jnp instead of import numpy as np.
  • Just-in-time (JIT) compilation: In many cases, this makes JAX much faster than NumPy.
  • Automatic differentiation: We use this to get derivatives of functions—usually, to train models through gradient descent.
  • Automatic vectorization: We can easily transform a function that works on single examples into a function that processes entire batches of data.
  • Parallelism: It's easy to split up a large model across multiple devices (e.g. GPUs).
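As a rough illustration of three of these features, here is a minimal sketch on a toy function. The function loss and its inputs are made up for illustration, and chosen so the results are easy to check by hand. Parallelism is left out, since it needs more than one device to demonstrate.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # A toy scalar function of a weight vector w and one input vector x.
    return jnp.sum((w * x) ** 2)


w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, 0.5, 0.5])

# Automatic differentiation: gradient of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# JIT compilation: the first call traces and compiles; later calls reuse the compiled code.
fast_grad_loss = jax.jit(grad_loss)

# Automatic vectorization: share w across the batch (None), map over axis 0 of xs.
batched_loss = jax.vmap(loss, in_axes=(None, 0))

xs = jnp.ones((4, 3))  # a batch of 4 input vectors

print(fast_grad_loss(w, x))  # 2 * w * x**2, i.e. [0.5, 1.0, 1.5]
print(batched_loss(w, xs))   # one loss per input: [14., 14., 14., 14.]
```

Note that we never wrote a batched version of loss, or a derivative of it, by hand: both were derived from the single-example function.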

Automatic differentiation and JIT compilation usually work in the background of ML frameworks, but JAX exposes them as explicit transformations that you can apply to (almost) arbitrary Python functions, which makes them very powerful.

Because JAX is so general-purpose, Feedbax is built not just on JAX, but also on:

  • Equinox, which allows us to define PyTorch-like modules, making it easier to organize our models;
  • Optax, which provides optimizers (like Adam) of the kind you'd normally find in an ML framework;
  • Diffrax, which provides numerical solvers for differential equations.

My favourite part about working with JAX is how nicely it plays with nested containers of data, which JAX calls PyTrees.

PyTrees

Let's start with a list and a dict that contain similar values.

some_list = [1, 2, 3]

some_dict = {'a': 1, 'b': 2, 'c': 3}

In standard Python, a comprehension is a typical way of applying some computation to every value in a list or dict.

[x ** 2 for x in some_list]
[1, 4, 9]
{k: x ** 2 for k, x in some_dict.items()}
{'a': 1, 'b': 4, 'c': 9}

While list and dict comprehensions are similar, they're not interchangeable. If our data is a list we can use the first method and get a list in return. But as soon as we introduce some data that's stored in a dict, we need to change our code.

Conveniently, JAX provides a function tree_map that behaves the same way for both lists and dicts.

from jax.tree_util import tree_map

tree_map(lambda x: x ** 2, some_list)
[1, 4, 9]
tree_map(lambda x: x ** 2, some_dict)
{'a': 1, 'b': 4, 'c': 9}
Python's built-in map

Python includes a built-in function map which is similar in principle to tree_map. For example, we can do list(map(lambda x: x**2, some_list)) to get the same result as tree_map(lambda x: x**2, some_list).

However:

  • map doesn't necessarily return the same type of data structure passed to it. To square all the values in a dict, and return a dict like we did with tree_map, we'd have to do something like dict(zip(some_dict.keys(), map(lambda x: x**2, some_dict.values()))). That's much less readable than the JAX solution.
  • map only works one level deep into a container. This isn't an issue with some_list, which is a flat list of numbers. But in many other cases the data we'll want to transform will be in complex, nested trees.
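To see the second point concretely, compare the two on a small nested list (made up for illustration):

```python
from jax.tree_util import tree_map

nested = [[1, 2], [3]]

# tree_map descends into the inner lists and squares every number.
print(tree_map(lambda x: x ** 2, nested))  # [[1, 4], [9]]

# Built-in map only goes one level deep, so each x here is a whole inner list,
# and squaring a list raises a TypeError.
try:
    list(map(lambda x: x ** 2, nested))
except TypeError as err:
    print("map failed:", err)
```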

Even better, tree_map works on nested containers.

import jax.numpy as jnp

some_data = [{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': jnp.array([2, 2, 2])}]]

tree_map(lambda x: x ** 2, some_data)
[{'p': [1, 4], 'x': 1.0},
 [25, 36, 49, 64, {'y': Array([4, 4, 4], dtype=int32)}]]

How does this work? JAX treats both lists and dicts—and any nested structures of lists and dicts—as PyTrees.

A PyTree's leaves are the data it ultimately contains.

from jax.tree_util import tree_leaves, tree_structure

tree_leaves(some_data)
[1, 2, 1.0, 5, 6, 7, 8, Array([2, 2, 2], dtype=int32)]

Those leaves are arranged in a tree with a certain structure.

tree_structure(some_data)
PyTreeDef([{'p': [*, *], 'x': *}, [*, *, *, *, {'y': *}]])

Note that the JAX array counts as a leaf, not as part of the tree structure. What does JAX treat as a leaf, and what does it treat as the structure?

By default:

  • leaves (AKA leaf nodes) are the objects JAX does not treat as containers. These include NumPy and JAX arrays, as well as basic data types like int, float, str, and bool;
  • internal nodes are lists, dicts, and tuples (plus a few other built-in containers), which JAX recognizes as PyTrees themselves. When JAX encounters these, its default stance is that "nesting continues here", so it looks inside the node for leaves, or even deeper layers of nodes.
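A quick way to check these defaults is to flatten a tuple of mixed values. One detail not listed above: None and empty containers are treated as nodes with no children, so they contribute no leaves at all. A minimal sketch:

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves

# A tuple (an internal node) holding a mix of values.
mixed = ("label", 2, 3.5, True, jnp.zeros(2), None, [])

leaves = tree_leaves(mixed)

# The str, int, float, bool and array are leaves; None and [] add nothing.
print(len(leaves))  # 5
print(leaves[:4])   # ['label', 2, 3.5, True]
```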

Importantly, we can change what counts as a leaf. Many functions that operate on PyTrees can take an argument is_leaf.

tree_leaves(some_data, is_leaf=lambda x: isinstance(x, dict))
[{'p': [1, 2], 'x': 1.0}, 5, 6, 7, 8, {'y': Array([2, 2, 2], dtype=int32)}]
tree_structure(some_data, is_leaf=lambda x: isinstance(x, dict))
PyTreeDef([*, [*, *, *, *, *]])

Here, we've told JAX to treat dicts as leaves, rather than as containers. Now, the dicts appear whole and unflattened in the list of leaves, and the PyTree structure reflects this.

To get the leaves and the tree structure in one call, use tree_flatten:

from jax.tree_util import tree_flatten

leaves, structure = tree_flatten(some_data)

Given both leaves and the structure, tree_unflatten builds a PyTree. Let's reconstruct the original some_data:

from jax.tree_util import tree_unflatten

tree_unflatten(structure, leaves)
[{'p': [1, 2], 'x': 1.0}, [5, 6, 7, 8, {'y': Array([2, 2, 2], dtype=int32)}]]
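Together, tree_flatten and tree_unflatten give a useful mental model for tree_map: flatten the tree, apply the function to each leaf, then rebuild the same structure. The sketch below illustrates the idea; it is not how JAX actually implements tree_map internally.

```python
from jax.tree_util import tree_flatten, tree_unflatten, tree_map

data = [{'p': [1, 2], 'x': 3}, (4, 5)]


def square(x):
    return x ** 2


# Flatten, apply the function to each leaf, then rebuild the same structure.
leaves, structure = tree_flatten(data)
manual = tree_unflatten(structure, [square(leaf) for leaf in leaves])

print(manual)  # [{'p': [1, 4], 'x': 9}, (16, 25)]
```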

Because JAX understands the structure of PyTrees, we can apply operations to multiple PyTrees when their structures match.

some_arrays = [
    (jnp.array([1, 2]), jnp.array([3, 4])),
    jnp.array([5, 6])
]

some_other_arrays = [
    (jnp.array([7, 8]), jnp.array([3, 4])),
    jnp.array([1, 1])
]

tree_structure(some_arrays) == tree_structure(some_other_arrays)
True
tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays
)
[(Array([ 8, 10], dtype=int32), Array([6, 8], dtype=int32)),
 Array([6, 7], dtype=int32)]

Here, tree_map works "leafwise" to pick out the arguments to a function: the x values are the leaves from some_arrays, and the y values are the matching leaves from some_other_arrays.
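If the structures don't match, there is no sensible way to pair up the leaves, and tree_map raises an error instead. A minimal sketch (the exact error message may differ between JAX versions):

```python
from jax.tree_util import tree_map, tree_structure

two_items = [1, 2]
three_items = [1, 2, 3]

print(tree_structure(two_items) == tree_structure(three_items))  # False

# The lists have different lengths, so leaves can't be matched one-to-one.
try:
    tree_map(lambda x, y: x + y, two_items, three_items)
except ValueError as err:
    print("Structure mismatch:", err)
```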

In this example, the result is different if we tell JAX to treat tuples as leaves. The first two JAX arrays in each PyTree are contained in a tuple, so the first x passed to the function will be a tuple of two arrays, as will the first y. Applying + to two tuples concatenates them, rather than adding their elements.

tree_map(
    lambda x, y: x + y,
    some_arrays,
    some_other_arrays,
    is_leaf=lambda x: isinstance(x, tuple)
)
[(Array([1, 2], dtype=int32),
  Array([3, 4], dtype=int32),
  Array([7, 8], dtype=int32),
  Array([3, 4], dtype=int32)),
 Array([6, 7], dtype=int32)]

The array that's not inside a tuple gets added the same way it did before, because it still counts as a leaf—it's just that now, tuples also count as leaves.

A PyTree of your own

A PyTree is any kind of container that JAX knows how to flatten and unflatten. By default, this includes lists, dicts, and tuples.

When the default containers aren't enough for us, we can define our own types of containers, and tell JAX how to flatten and unflatten them. After that, JAX will treat them as PyTrees!

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class TwoValues:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        return (self.a, self.b), None  # leaves, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        # aux_data is the second value returned by tree_flatten (None here).
        return cls(*leaves)

Here, the method tree_flatten tells JAX how to flatten a TwoValues object into its leaves (plus any auxiliary data, which here is just None), and tree_unflatten tells JAX how to rebuild a TwoValues object from that auxiliary data and the leaves.

As its name suggests, the decorator register_pytree_node_class registers our new PyTree type with JAX.

Now we can use TwoValues as part of any PyTree:

tree = (TwoValues(12, 45), 3, {'a': TwoValues(4, 5)})

tree_leaves(tree)
[12, 45, 3, 4, 5]
tree_structure(tree)
PyTreeDef((CustomNode(TwoValues[None], [*, *]), *, {'a': CustomNode(TwoValues[None], [*, *])}))
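TwoValues returns None as its auxiliary data. The auxiliary data is the place for static information that should be part of the tree's structure rather than one of its leaves. As a sketch, here is a hypothetical Labeled container that keeps a string label out of the leaves, so that tree_map transforms the value but carries the label along unchanged:

```python
from jax.tree_util import register_pytree_node_class, tree_map, tree_leaves


@register_pytree_node_class
class Labeled:
    """Hypothetical container: one value (a leaf) plus a static label."""

    def __init__(self, value, label):
        self.value = value
        self.label = label

    def tree_flatten(self):
        # `value` is the only leaf; `label` goes into aux_data, so it becomes
        # part of the structure. aux_data should be hashable and comparable.
        return (self.value,), self.label

    @classmethod
    def tree_unflatten(cls, aux_data, leaves):
        return cls(leaves[0], aux_data)


tree = [Labeled(3, "speed"), Labeled(4, "force")]
squared = tree_map(lambda x: x ** 2, tree)

print(tree_leaves(tree))                        # [3, 4]; labels are not leaves
print([(t.label, t.value) for t in squared])    # [('speed', 9), ('force', 16)]
```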

Equinox

Most objects in Feedbax are derived from equinox.Module.

Equinox adds some useful tools to JAX.

In particular, equinox.Module allows us to easily define classes that are PyTrees, and that combine model parameters with model computations.

import equinox as eqx
import jax


class SomeModel(eqx.Module):
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2


# Construct an example model.
model = SomeModel(3, jnp.array([1, 2, 3]))

In our class definition, the method __call__ tells Python how a SomeModel object should behave when we call it like a function:

model(2.5)
Array([ 5.5,  8. , 10.5], dtype=float32, weak_type=True)

This is a nice way to define and execute our model computation.

Another convenient thing about Equinox Module is that it's a dataclass. In a normal Python class, to assign param1 and param2 as instance attributes we'd have to do this:

class SomeModel:
    def __init__(self, param1: int, param2: jax.Array):
        self.param1 = param1
        self.param2 = param2

    def __call__(self, x: float):
        return self.param1 + x * self.param2

When our class is a dataclass, it automatically defines a default __init__ method like the one above. We just have to define the list of parameters (that is, dataclass fields):

from dataclasses import dataclass

@dataclass
class SomeModel:
    param1: int
    param2: jax.Array

    def __call__(self, x: float):
        return self.param1 + x * self.param2

Any class or subclass we define from eqx.Module will automatically work this way, without needing to add the @dataclass decorator.

Note

We can still add our own __init__ method to a dataclass if we need to do something fancier than just assigning values to fields.

In case only small modifications to __init__ are needed, it may be convenient to define __post_init__ instead.

The best thing about Equinox modules is that they are PyTrees:

# Get a flattened list of model parameters.
tree_leaves(model)
[3, Array([1, 2, 3], dtype=int32)]

It turns out this is very useful for structuring models, but that's beyond the scope of this example.

Similarity of Equinox and PyTorch modules

Equinox's Module is kind of like PyTorch's nn.Module. However, PyTorch modules:

  • are not PyTrees, because PyTorch has no general, built-in concept of PyTrees;
  • are not automatically dataclasses, and it can be kind of problematic to convert them;
  • define the model computation in the forward method, rather than __call__. Technically though, PyTorch still has to define __call__ in the background to have its module objects behave like functions.

Vectorisation and vmap

The power of PyTrees goes much deeper than we've seen here. The core JAX transformations, such as jax.vmap and jax.grad, accept PyTrees as inputs and return PyTrees as outputs.

Note

If you run into problems with jax.vmap, try Equinox's eqx.filter_vmap instead. It does the same thing, but a little more intelligently: by default, it only maps over JAX arrays and passes other values (such as Python ints) through unchanged.
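As a sketch of what this means in practice, here jax.grad differentiates with respect to a whole dict of parameters, and jax.vmap maps a single-example prediction over a batch. The params dict, predict, and loss are made up for illustration:

```python
import jax
import jax.numpy as jnp

# Parameters stored as a PyTree: a dict of arrays.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}


def predict(params, x):
    # Prediction for a single input vector x.
    return jnp.dot(params['w'], x) + params['b']


def loss(params, xs, ys):
    # Map predict over the leading axis of xs; params are shared (in_axes=None).
    preds = jax.vmap(predict, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)


xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 2.0, 3.0])

# The gradient is a PyTree with the same structure as params.
grads = jax.grad(loss)(params, xs, ys)
print(grads['w'], grads['b'])  # approximately [0.667 0.667] and 1.0
```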

Functions and states

JAX plays best with pure functions. Let's see what that means.

Perhaps you are familiar with object-oriented programming, where classes define how objects possess and manipulate their internal states. For example, let's define a type of object that 1) possesses two attributes, and 2) when it's called, returns a result, but also internally updates one of its attributes.

class StatefulFoo:
    smee: int
    a: int

    def __init__(self, a: int):
        self.smee = 0
        self.a = a

    def __call__(self, x: int):

        if x > 3:
            self.smee = 2

        return self.a * x


a = 2
foo = StatefulFoo(a)
x = 1

print("\t\tx\tsmee")

for i in range(7):
    x = foo(x)

    print(f"Step {i}:\t\t{x}\t{foo.smee}")
      x   smee
Step 0:     2   0
Step 1:     4   0
Step 2:     8   2
Step 3:     16  2
Step 4:     32  2
Step 5:     64  2
Step 6:     128 2

Importantly, the internal state (the value of foo.smee) changes once a value greater than 3 is passed to foo. This is obvious here, since we're printing foo.smee on every step. But under different circumstances, we might not even know it had changed.

Seen as a function, the main thing foo does is return the value self.a * x. But it also has the side effect of altering foo.smee.

On the other hand, a pure function has no side effects: everything the function does is captured by how its input is turned into its return value.

We can still do what we did with foo.smee, except that smee can no longer be hidden. It just needs to be part of the input and output of the function.

class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    def __call__(self, x: int, smee: int):

        if x > 3:
            smee = 2

        return self.a * x, smee

a = 2
foo = PureFoo(a)
smee = 0
x = 1

print("\t\tx\tsmee")

for i in range(7):
    x, smee = foo(x, smee)

    print(f"Step {i}:\t\t{x}\t{smee}")
      x   smee
Step 0:     2   0
Step 1:     4   0
Step 2:     8   2
Step 3:     16  2
Step 4:     32  2
Step 5:     64  2
Step 6:     128 2

Maybe this doesn't seem as nice as StatefulFoo, but it is totally transparent. And if we keep building up our programs in this way, it forces us to start adding more structure to the inputs and outputs of our functions.

@dataclass
class Data:
    x: int
    smee: int


class PureFoo:
    a: int

    def __init__(self, a: int):
        self.a = a

    # Takes Data, and returns Data.
    def __call__(self, data: Data) -> Data:

        if data.x > 3:
            smee = 2
        else:
            smee = data.smee

        return Data(self.a * data.x, smee)


a = 2
foo = PureFoo(a)
data = Data(x=1, smee=0)

print("\t\tx\tsmee")

for i in range(7):
    data = foo(data)

    print(f"Step {i}:\t\t{data.x}\t{data.smee}")
      x   smee
Step 0:     2   0
Step 1:     4   0
Step 2:     8   2
Step 3:     16  2
Step 4:     32  2
Step 5:     64  2
Step 6:     128 2

It turns out that as our programs grow complex, this style will work at least as well as the stateful style ever did—and without hiding anything.

In Feedbax, the relationship between a model and its state is like the relationship between PureFoo and Data, in this example. A model does not possess state, it operates on it.

Similarly, we never change a state object by directly reassigning its values. For instance, with the Data class above, we would never do this:

data = Data(x=1, smee=0)
data.smee = 2

As we'll see shortly, this won't be a problem. We'll just need to define the alteration to data as some function that takes data as its input, and constructs the altered version as its output.
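For plain dataclasses like Data, the standard library already provides such a function: dataclasses.replace returns a copy with some fields changed. A minimal sketch; making Data frozen is an optional extra (not used in the example above) that turns accidental assignments into errors:

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen: assigning to a field raises an error
class Data:
    x: int
    smee: int


data = Data(x=1, smee=0)

# Build an altered copy instead of modifying `data` in place.
new_data = dataclasses.replace(data, smee=2)

print(data)      # Data(x=1, smee=0)
print(new_data)  # Data(x=1, smee=2)

try:
    data.smee = 2
except dataclasses.FrozenInstanceError:
    print("Direct assignment is not allowed")
```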

Equinox and pure functions

It might seem a little odd that we contrasted object-oriented programming with purely functional programming, and then kept defining our "pure function" as a class!

It's not really odd, though. What matters is that our classes behave like pure functions because of the way we define __call__. And classes do one very convenient thing for us: they let us keep fixed model parameters (like a) in the same place as a method that defines the model's computation.

This is essentially what eqx.Module is for. And it forces us to code in a functional style. Watch what happens if we try to change one of the attributes of an Equinox module:

class Bar(eqx.Module):
    a: int

my_bar = Bar(a=3)

my_bar.a = 4
---------------------------------------------------------------------------
FrozenInstanceError                       Traceback (most recent call last)
Cell In[27], line 6
      2     a: int
      4 my_bar = Bar(a=3)
----> 6 my_bar.a = 4

File <string>:4, in __setattr__(self, name, value)

FrozenInstanceError: cannot assign to field 'a'

Things are no different if the object tries to change itself directly:

class Baz(eqx.Module):
    a: int

    def __call__(self, x: int):
        self.a = x


my_baz = Baz(a=3)

my_baz(4)
---------------------------------------------------------------------------
FrozenInstanceError                       Traceback (most recent call last)
Cell In[93], line 10
      5         self.a = x
      8 my_baz = Baz(a=3)
---> 10 my_baz(4)

Cell In[93], line 5, in Baz.__call__(self, x)
      4 def __call__(self, x: int):
----> 5     self.a = x

File <string>:4, in __setattr__(self, name, value)

FrozenInstanceError: cannot assign to field 'a'

In other words, Equinox modules are immutable. Immutability goes hand in hand with pure functions, because it ensures that the internal state of our objects cannot be altered in the background.

The model is data, too

We've been referring to a as a fixed parameter. Earlier in this example, PureFoo could return an updated Data, but it never changed its own a!

We can still change a. We just need another kind of function, that operates on PureFoo, like PureFoo operated on Data.

def foo_update(foo: PureFoo) -> PureFoo:
    a = foo.a + 1
    return PureFoo(a)


foo = PureFoo(a=2)
foo_new = foo_update(foo)


print(f"Old a: {foo.a}")
print(f"New a: {foo_new.a}")
Old a: 2
New a: 3

This time we've just defined a plain old Python function with def, but we could also have used an Equinox module to define this function, if it had parameters of its own to remember.

JAX arrays are immutable

Here's something we can do in NumPy:

import numpy as np

some_array = np.zeros((3, 3))

# Modify the array in-place.
some_array[0, 1:] = 5

some_array
array([[0., 5., 5.],
       [0., 0., 0.],
       [0., 0., 0.]])

The same thing doesn't work in JAX.

import jax.numpy as jnp

some_array = jnp.zeros((3, 3))

some_array[0, 1:] = 5
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[15], line 5
      1 import jax.numpy as jnp
      3 some_array = jnp.zeros((3, 3))
----> 5 some_array[0, 1:] = 5

File ~/.miniforge3/envs/fx/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:285, in _unimplemented_setitem(self, i, x)
    280 def _unimplemented_setitem(self, i, x):
    281   msg = ("'{}' object does not support item assignment. JAX arrays are "
    282          "immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` "
    283          "or another .at[] method: "
    284          "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html")
--> 285   raise TypeError(msg.format(type(self)))

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

JAX arrays are immutable! That means we can't reach in and change an array in-place. We always have to perform some transformation that returns a new array.

JAX provides the at-set syntax for assigning a value to an index.

some_array = some_array.at[0, 1:].set(5)

some_array
Array([[0., 5., 5.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
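Besides set, the .at[] property offers other update methods, such as add and multiply, and every one of them returns a new array rather than modifying the old one. A minimal sketch:

```python
import jax.numpy as jnp

a = jnp.zeros((3, 3))

b = a.at[1, :].add(2.0)        # new array: row 1 incremented by 2
c = b.at[1, 0].multiply(5.0)   # new array: element (1, 0) scaled by 5
row = c.at[1].get()            # equivalent to c[1]

print(a.sum())  # 0.0: the original array is untouched
print(row)      # [10.  2.  2.]
```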

The right-hand side of the assignment can be seen as a function that takes some_array and returns a new array object with the requested alteration.

In some cases it might appear that JAX performs in-place operations. For example, in NumPy we can do in-place addition like this:

another_array = np.zeros((2, 2))

another_array += 3

another_array
array([[3., 3.],
       [3., 3.]])

We can tell that the operation is in-place because another_array has the same object ID as before:

id_before = id(another_array)

another_array += 10

id_after = id(another_array)

id_before == id_after
True

On the other hand, this is not the case in JAX:

# Now using jnp, not np!
another_array = jnp.zeros((2, 2))

id_before = id(another_array)

another_array += 10

id_after = id(another_array)

id_before == id_after
False

A different ID means a different Python object. In other words, JAX treats another_array += 10 like it would treat the purely functional another_array = another_array + 10, and not as an in-place update.

To make this explicit in your own code, you can write another_array = another_array + 10 instead.
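The difference matters when two names refer to the same array. A minimal sketch: with NumPy, an in-place update is visible through every name bound to the array, while with JAX the other name keeps the old value.

```python
import numpy as np
import jax.numpy as jnp

np_array = np.zeros(2)
np_alias = np_array       # a second name for the same NumPy array
np_array += 1             # in-place: np_alias sees the change too

jax_array = jnp.zeros(2)
jax_alias = jax_array     # a second name for the same JAX array
jax_array += 1            # rebinds jax_array to a new array; jax_alias is untouched

print(np_alias)   # [1. 1.]
print(jax_alias)  # [0. 0.]
```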

Immutability in Feedbax: performing surgery

In Feedbax, models and states can be really big PyTrees. Often we want to change just one part of them. But we don't want to keep writing huge functions that reconstruct the entire model, every time we want to replace just one piece.

Thankfully, Equinox provides a general-purpose function that can perform surgery.

Let's start with a pre-built model.

import jax

from feedbax.xabdeef import point_mass_nn_simple_reaches


context = point_mass_nn_simple_reaches(key=jax.random.PRNGKey(0))
model = context.model  # Shorthand

This model has a point mass of mass \(1.0\) as its skeleton.

model.step.mechanics.plant.skeleton
PointMass(mass=1.0)

As expected, if we try to directly alter the model to use a point mass of mass \(5.0\), an error is raised.

from feedbax.mechanics.skeleton import PointMass

# Try to replace the entire point mass
model.step.mechanics.plant.skeleton = PointMass(5.0)
---------------------------------------------------------------------------
FrozenInstanceError                       Traceback (most recent call last)
Cell In[4], line 4
      1 from feedbax.mechanics.skeleton import PointMass
      3 # Try to replace the entire point mass
----> 4 model.step.mechanics.plant.skeleton = PointMass(5.0)

File <string>:4, in __setattr__(self, name, value)

FrozenInstanceError: cannot assign to field 'skeleton'

# Or just try to change the mass
model.step.mechanics.plant.skeleton.mass = 5.0
---------------------------------------------------------------------------
FrozenInstanceError                       Traceback (most recent call last)
Cell In[7], line 2
      1 # Try to just change the mass
----> 2 model.step.mechanics.plant.skeleton.mass = 5.0

File <string>:4, in __setattr__(self, name, value)

FrozenInstanceError: cannot assign to field 'mass'

Instead, we use the tree_at function from Equinox. It takes a PyTree and returns a new PyTree in which one piece has been replaced; the original is left unchanged.

import equinox as eqx

model_heavy = eqx.tree_at(
    lambda m: m.step.mechanics.plant.skeleton,
    model,
    PointMass(5.0)
)

The first argument to tree_at is a "locator" function: when we pass it the model, it returns the part of the model we want to replace.

Lambda functions

Using Python's lambda syntax lets us define the function inline. This isn't strictly necessary, but it's common practice in JAX when we need to define functions to pick out parts of PyTrees.

The second argument is just model, which is the model we want to alter.

The third argument is the replacement part.

Here, the model with the modifications is assigned to model_heavy, and we can still refer to the original model as model.

Random number generation

One other way that JAX differs from NumPy is how it handles the generation of random numbers.

In NumPy, random number generators are stateful. We can see this by calling one twice.

np.random.random()
0.6916782346283932
np.random.random()
0.26970965558017235

The two numbers are different, but the input to the function was not: in both cases, we passed no arguments.

If a function's output differs when the input remains the same, it's not a pure function. In this case, the number changed because it was based on a state variable that changed in the background, like smee did earlier.

These aren't truly random numbers. They're pseudo-random: the outputs of a deterministic function that varies wildly with its input. In NumPy, we can control where the sequence starts by setting the random seed:

seed = 1234

np.random.seed(seed)
np.random.random()
0.1915194503788923
np.random.random()
0.6221087710398319
np.random.seed(seed)
np.random.random()
0.1915194503788923
np.random.random()
0.6221087710398319

Whenever we set the seed, the random numbers start from the same point. In principle, this makes our subsequent calls reproducible. However, because the state of the random number generator keeps changing in the background, our code may stop being reproducible if, at any point during execution, some other part of the program (such as a library function) makes even a single call to the global random number generator and changes its state.
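To see how fragile this is, here is a minimal sketch in which one extra call, standing in for unrelated code that happens to use the global generator, shifts every value that follows:

```python
import numpy as np

np.random.seed(1234)
first = np.random.random()

np.random.seed(1234)
_ = np.random.random()      # stands in for a call made elsewhere in the program
shifted = np.random.random()

# False: after the extra call, we got the second number in the sequence instead.
print(first == shifted)
```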

Note

The situation is similar in PyTorch, and the above example can be repeated with torch.rand and torch.manual_seed.

JAX takes a totally functional and transparent approach to random numbers. Whenever we want to generate a random number, we have to pass a key.

import jax.random as jr

seed = 5678
key = jr.PRNGKey(seed)

jr.uniform(key)
Array(0.376078, dtype=float32)

If we call a random function a second time with the same key, it returns the same result.

jr.uniform(key)
Array(0.376078, dtype=float32)

When we want new random numbers, we derive new keys from an existing one with split.

key1, key2 = jr.split(key)

print(jr.uniform(key1))
print(jr.uniform(key2))
0.6928469
0.040529132

This forces us to always be clear about the logic of how random numbers are generated. In JAX it's typical to see something like:

def generate_data(key):
    key_uniform, key_normal = jr.split(key)
    data_uniform = jr.uniform(key_uniform, (2, 4))
    data_normal = jr.normal(key_normal, (2, 4))
    return {'uniform': data_uniform, 'normal': data_normal}

Normally, if a function needs to generate random numbers internally, we pass it only a single key. The function decides how many keys it actually needs, and splits the one we send it accordingly.

key = jr.PRNGKey(seed)

generate_data(key)
{'uniform': Array([[0.1093446 , 0.97192943, 0.60279703, 0.5552217 ],
        [0.4859867 , 0.6875596 , 0.18040001, 0.6805732 ]], dtype=float32),
 'normal': Array([[-0.92087466, -0.99356407, -0.01340629,  0.2917211 ],
        [-0.09446456, -0.53876567,  0.04995674, -0.8308685 ]],      dtype=float32)}
# Use the same key for reproducible results.
generate_data(key)
{'uniform': Array([[0.1093446 , 0.97192943, 0.60279703, 0.5552217 ],
        [0.4859867 , 0.6875596 , 0.18040001, 0.6805732 ]], dtype=float32),
 'normal': Array([[-0.92087466, -0.99356407, -0.01340629,  0.2917211 ],
        [-0.09446456, -0.53876567,  0.04995674, -0.8308685 ]],      dtype=float32)}
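Two common patterns build on this. In a loop, we typically split off a fresh subkey on each step and keep the other key for future splits. And jr.split(key, n) returns n keys at once, which pairs naturally with jax.vmap. A minimal sketch (draw_batch is a hypothetical helper):

```python
import jax
import jax.random as jr

key = jr.PRNGKey(5678)

# Pattern 1: thread the key through a loop, splitting once per step.
samples = []
for _ in range(3):
    key, subkey = jr.split(key)  # keep `key` for later, consume `subkey` now
    samples.append(float(jr.uniform(subkey)))


def draw_batch(key):
    # Hypothetical helper: one (2, 4) array of normal samples per key.
    return jr.normal(key, (2, 4))


# Pattern 2: split into many keys at once and vmap over them.
keys = jr.split(jr.PRNGKey(5678), 5)
batch = jax.vmap(draw_batch)(keys)

print(samples)      # three different numbers
print(batch.shape)  # (5, 2, 4)
```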

Compared to random number generation in NumPy and PyTorch, I find that JAX takes a small amount of extra effort—in most cases, writing at most 1 extra line of code per function, to split keys—but I end up feeling significantly more comfortable that my code's output is actually reproducible.