PyTree operations

Feedbax provides a few convenience functions that operate on PyTrees and are not provided by its dependencies.

Indexing and assignment

feedbax.tree_take(tree: PyTree[Any, T], indices: ArrayLike, axis: int = 0, **kwargs: Any) -> PyTree[Any, T]

Indexes elements out of each array leaf of a PyTree.

Any non-array leaves are returned unchanged.

This function inherits the default indexing behaviour of JAX. If out-of-bounds indices are provided, no error will be raised.

Parameters:

Name Type Description Default
tree PyTree[Any, T]

Any PyTree whose array leaves are equivalently indexable, according to the other arguments to this function. For example, axis=0 could be used when the first dimension of every array leaf is a batch dimension, and indices specifies a subset of examples from the batch.

required
indices ArrayLike

The indices of the values to take from each array leaf.

required
axis int

The axis of the array leaves over which to take their values. Defaults to 0.

0
kwargs Any

Additional arguments to jax.numpy.take.

{}

Returns:

Type Description
PyTree[Any, T]

A PyTree with the same structure as tree, where array leaves from tree have been replaced by indexed-out elements.
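
The behaviour described above can be approximated with plain JAX. The following is a minimal sketch, not Feedbax's implementation: `take_sketch` is a hypothetical helper, and it assumes that "array leaf" means an instance of `jax.Array`.

```python
import jax
import jax.numpy as jnp


def take_sketch(tree, indices, axis=0, **kwargs):
    """Hypothetical stand-in that mirrors the documented behaviour."""

    def take_leaf(leaf):
        # Index array leaves; return anything else unchanged.
        if isinstance(leaf, jax.Array):
            return jnp.take(leaf, indices, axis=axis, **kwargs)
        return leaf

    return jax.tree_util.tree_map(take_leaf, tree)


# A batch of 4 examples; the string leaf is not an array.
batch = {"pos": jnp.arange(12.0).reshape(4, 3), "name": "trial"}

# Select examples 0 and 2 along the batch axis.
subset = take_sketch(batch, jnp.array([0, 2]))

# Extra keyword arguments reach jnp.take; mode="clip" clamps
# the out-of-bounds index 10 to the last row.
clipped = take_sketch(batch, jnp.array([10]), mode="clip")
```

Passing an explicit `mode` is one way to make the handling of out-of-bounds indices visible at the call site, since no error is raised by default.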

feedbax.tree_set(tree: PyTree[Any | Shaped[Array, 'batch *?dims'], T], items: PyTree[Any | Shaped[Array, '*?dims'], T], idx: int) -> PyTree[Any | Shaped[Array, 'batch *?dims'], T]

Perform an out-of-place update of each array leaf of a PyTree.

Non-array leaves are simply replaced by their matching leaves in items.

For example, if tree is a PyTree of states over time, whose first dimension is the time step, and items is a PyTree of states for a single time step, this function can be used to insert the latter into the former at a given time index.

Parameters:

Name Type Description Default
tree PyTree[Any | Shaped[Array, 'batch *?dims'], T]

Any PyTree whose array leaves share a first dimension of the same length, for example a batch dimension.

required
items PyTree[Any | Shaped[Array, '*?dims'], T]

Any PyTree with the same structure as tree, and whose array leaves have the same shape as the corresponding leaves in tree, but lacking the first dimension.

required
idx int

The index along the first dimension of the array leaves of tree into which to insert the array leaves of items.

required

Returns:

Type Description
PyTree[Any | Shaped[Array, 'batch *?dims'], T]

A PyTree with the same structure as tree, where the array leaves of items have been inserted as the idx-th elements of the corresponding array leaves of tree.
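
As an illustration of the states-over-time use case, here is a minimal sketch built on JAX's `.at[...].set(...)` out-of-place update. `set_sketch` is a hypothetical helper and is only assumed to behave like the function documented above; it is not taken from the Feedbax source.

```python
import jax
import jax.numpy as jnp


def set_sketch(tree, items, idx):
    """Hypothetical stand-in that mirrors the documented behaviour."""

    def set_leaf(leaf, item):
        # Array leaves get an out-of-place update at index idx.
        if isinstance(leaf, jax.Array):
            return leaf.at[idx].set(item)
        # Non-array leaves are replaced by the matching leaf of items.
        return item

    return jax.tree_util.tree_map(set_leaf, tree, items)


# States over 5 time steps, and the state for a single time step.
states = {"pos": jnp.zeros((5, 2)), "label": "old"}
step = {"pos": jnp.array([1.0, 2.0]), "label": "new"}

updated = set_sketch(states, step, 3)
```

Because the update is out-of-place, `states` is left untouched and only `updated` contains the inserted time step.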

Mapping, unzipping, and stacking

feedbax.get_ensemble(func: Callable[..., PyTree[Any, S]], *args: Any, n_ensemble: int, key: PRNGKeyArray, **kwargs: Any) -> PyTree[Any, S]

Vmap a function over a set of random keys.

Parameters:

Name Type Description Default
func Callable[..., PyTree[Any, S]]

A function that returns a PyTree, and whose final keyword argument is key: PRNGKeyArray.

required
n_ensemble int

The number of keys to split; i.e. the size of the batch dimension in the array leaves of the returned PyTree.

required
*args Any

The positional arguments to func.

()
key PRNGKeyArray

The key to split to perform the vmap.

required
**kwargs Any

The keyword arguments to func.

{}
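
The pattern of vmapping over split keys can be sketched with `jax.random.split` and `jax.vmap`. This is an illustrative approximation rather than Feedbax's code: `ensemble_sketch` and `init_params` are hypothetical names, and the sketch assumes that `func` returns only arrays, which plain `jax.vmap` requires.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def ensemble_sketch(func, *args, n_ensemble, key, **kwargs):
    """Hypothetical stand-in: call func once per split key, batched by vmap."""
    keys = jr.split(key, n_ensemble)
    # Only the key varies across the ensemble; args and kwargs are shared.
    return jax.vmap(lambda k: func(*args, **kwargs, key=k))(keys)


def init_params(n_in, n_out, *, key):
    """Hypothetical example function whose final keyword argument is key."""
    w_key, b_key = jr.split(key)
    return {
        "w": jr.normal(w_key, (n_out, n_in)),
        "b": jr.normal(b_key, (n_out,)),
    }


# Four independently initialised parameter sets, stacked on a leading axis.
params = ensemble_sketch(init_params, 3, 2, n_ensemble=4, key=jr.PRNGKey(0))
```

Each array leaf of the result gains a leading dimension of size `n_ensemble`.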

feedbax.tree_unzip(tree: PyTree[Tuple[Any, ...], T]) -> Tuple[PyTree[Any, T], ...]

Unzips a PyTree of tuples into a tuple of PyTrees.

Note

Something similar could be done with tree_transpose, but outer_treedef would need to be specified.

This version has zip-like behaviour, in that 1) the input tree should be flattenable to tuples, when tuples are treated as leaves; 2) the shortest of those tuples determines the length of the output.
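
The zip-like behaviour can be sketched by flattening with tuples treated as leaves and then zipping the resulting tuples. `unzip_sketch` is a hypothetical helper written for illustration and is only assumed to match the description above.

```python
import jax.tree_util as jtu


def unzip_sketch(tree):
    """Hypothetical stand-in: unzip a PyTree of tuples into a tuple of PyTrees."""
    # Treat every tuple as a leaf, so flattening yields a list of tuples.
    tuples, treedef = jtu.tree_flatten(
        tree, is_leaf=lambda x: isinstance(x, tuple)
    )
    # zip stops at the shortest tuple, which fixes the length of the output.
    return tuple(
        jtu.tree_unflatten(treedef, list(column)) for column in zip(*tuples)
    )


# The second tuple is longer; its third element is dropped.
pairs = {"a": (1, "x"), "b": (2, "y", "extra")}
numbers, letters = unzip_sketch(pairs)
```

Here the output has two trees because the shortest tuple in `pairs` has two elements.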

feedbax.tree_map_tqdm(f: Callable[..., S], tree: PyTree[Any, T], *rest: PyTree[Any, T], labels: Optional[PyTree[str, T]] = None, verbose: bool = False, is_leaf: Optional[Callable[..., bool]] = None) -> PyTree[S, T]

Adds a progress bar to tree_map.

Parameters:

Name Type Description Default
f Callable[..., S]

The function to map over the tree.

required
tree PyTree[Any, T]

The PyTree to map over.

required
*rest PyTree[Any, T]

Additional arguments to f, as PyTrees with the same structure as tree.

()
labels Optional[PyTree[str, T]]

A PyTree of labels for the leaves of tree, to be displayed on the progress bar.

None
is_leaf Optional[Callable[..., bool]]

A function that returns True for leaves of tree.

None

feedbax.tree_call(tree: PyTree[Any, T], *args: Any, exclude: Callable = lambda _: False, is_leaf: Optional[Callable] = None, **kwargs: Any) -> PyTree[Any, T]

Returns a tree of the return values of a PyTree's callable leaves.

Every callable leaf is passed the same *args, **kwargs.

Non-callable leaves, and callable leaves that satisfy exclude, are passed through as-is.

Parameters:

Name Type Description Default
tree PyTree[Any, T]

Any PyTree.

required
*args Any

Positional arguments to pass to each callable leaf.

()
exclude Callable

A function that returns True for any callable leaf that should not be called.

lambda _: False
**kwargs Any

Keyword arguments to pass to each callable leaf.

{}
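
A minimal sketch of this calling behaviour, using `jax.tree_util.tree_map`. The helper `call_sketch` and the example functions are hypothetical, and the sketch is only assumed to agree with the description above.

```python
import jax.tree_util as jtu


def call_sketch(tree, *args, exclude=lambda _: False, is_leaf=None, **kwargs):
    """Hypothetical stand-in: call each callable leaf with the same arguments."""

    def call_leaf(leaf):
        if callable(leaf) and not exclude(leaf):
            return leaf(*args, **kwargs)
        # Non-callable and excluded leaves pass through as-is.
        return leaf

    return jtu.tree_map(call_leaf, tree, is_leaf=is_leaf)


def double(x):
    return 2 * x


def square(x):
    return x ** 2


tree = {"double": double, "square": square, "const": 5}

# Every callable leaf receives the argument 3.
called = call_sketch(tree, 3)

# Exclude one callable so that it is returned uncalled.
partly_called = call_sketch(tree, 3, exclude=lambda f: f is square)
```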

feedbax.tree_stack(trees: Sequence[PyTree[Any, T]], axis: int = 0) -> PyTree[Any, T]

Returns a PyTree whose array leaves stack those of the PyTrees in trees.

Example

import jax.numpy as jnp
from feedbax import tree_stack

a = [jnp.array([1, 2]), jnp.array([3, 4])]
b = [jnp.array([5, 6]), jnp.array([7, 8])]

tree_stack([a, b], axis=0)
# [jnp.array([[1, 2], [5, 6]]), jnp.array([[3, 4], [7, 8]])]

Derived from this GitHub gist.

Parameters:

Name Type Description Default
trees Sequence[PyTree[Any, T]]

A sequence of PyTrees with the same structure, and whose array leaves have the same shape.

required
axis int

The axis along which to stack the array leaves.

0
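
To show the effect of the axis argument, the stacking can be approximated with `jax.tree_util.tree_map` and `jnp.stack`. `stack_sketch` is a hypothetical helper; it is a sketch of the documented behaviour for trees whose leaves are all arrays, not the Feedbax implementation.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def stack_sketch(trees, axis=0):
    """Hypothetical stand-in: stack matching array leaves across trees."""
    return jtu.tree_map(lambda *leaves: jnp.stack(leaves, axis=axis), *trees)


a = [jnp.array([1, 2]), jnp.array([3, 4])]
b = [jnp.array([5, 6]), jnp.array([7, 8])]

# axis=0 adds the new dimension first; axis=1 adds it second.
stacked_first = stack_sketch([a, b], axis=0)
stacked_second = stack_sketch([a, b], axis=1)
```

With `axis=1`, the first leaf of the result pairs up corresponding elements of the inputs, giving `[[1, 5], [2, 6]]`.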

Leaf labels

feedbax.tree_labels(tree: PyTree[Any, T], join_with: str = '_', is_leaf: Optional[Callable[..., bool]] = None) -> PyTree[str, T]

Return a PyTree of labels based on each leaf's key path.

When tree is a flat dict:

tree_labels(tree) == {k: str(k) for k in tree.keys()}

When tree is a flat list:

tree_labels(tree) == [str(i) for i in range(len(tree))]

Verbose tree_map

This function is useful for creating descriptive labels when using tree_map to apply an expensive operation to a PyTree.

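from jax.tree_util import tree_map
from feedbax import tree_labels

def expensive_op(x):
    # Something time-consuming
    ...

def verbose_expensive_op(leaf, label):
    # Report which leaf is being processed, then do the work.
    print(f"Processing leaf: {label}")
    return expensive_op(leaf)

# `tree` is any PyTree defined elsewhere.
result = tree_map(
    verbose_expensive_op,
    tree,
    tree_labels(tree),
)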

A similar use case combines this function with tree_map_tqdm to label a progress bar:

result = tree_map_tqdm(
    expensive_op,
    tree,
    labels=tree_labels(tree),
)

Parameters:

Name Type Description Default
tree PyTree[Any, T]

The PyTree for which to generate labels.

required
join_with str

The string with which to join a leaf's path keys, to form its label.

'_'
is_leaf Optional[Callable[..., bool]]

An optional function that returns a boolean, which determines whether each node in tree should be treated as a leaf.

None
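
For nested trees, labels of this kind can be built by joining the keys along each leaf's path. The sketch below uses `jax.tree_util.tree_flatten_with_path`; `labels_sketch` and `key_to_str` are hypothetical helpers, and the exact label format Feedbax produces for nested trees is an assumption here.

```python
import jax.tree_util as jtu


def key_to_str(key):
    """Hypothetical helper: turn one path key into a plain string."""
    if isinstance(key, jtu.DictKey):
        return str(key.key)
    if isinstance(key, jtu.SequenceKey):
        return str(key.idx)
    if isinstance(key, jtu.GetAttrKey):
        return key.name
    return str(key)


def labels_sketch(tree, join_with="_", is_leaf=None):
    """Hypothetical stand-in: label each leaf by its joined key path."""
    paths_and_leaves, treedef = jtu.tree_flatten_with_path(tree, is_leaf=is_leaf)
    labels = [
        join_with.join(key_to_str(k) for k in path)
        for path, _ in paths_and_leaves
    ]
    # Rebuild a tree of labels with the same structure as the input.
    return jtu.tree_unflatten(treedef, labels)


tree = {"net": {"w": 1.0, "b": 2.0}, "steps": [10, 20]}
labels = labels_sketch(tree)
dotted = labels_sketch(tree, join_with=".")
```

For flat dicts and flat lists this sketch reduces to the two identities given above.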

Random keys

feedbax.random_split_like_tree(key: PRNGKeyArray, tree: PyTree[Any, T], is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree[PRNGKeyArray | None, T]

Returns a split of random keys, as leaves of a target PyTree structure.

Derived from this comment on a discussion in the JAX GitHub repository.

Parameters:

Name Type Description Default
key PRNGKeyArray

The random key from which to split the tree of random keys.

required
tree PyTree[Any, T]

Any PyTree.

required
is_leaf Optional[Callable[[Any], bool]]

An optional function that decides whether each node in tree should be treated as a leaf, or traversed as a subtree.

None
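
One common way to get a tree of keys is to split once per leaf and unflatten the result into the target structure. The following is a sketch of that idea under stated assumptions: `split_like_sketch` is a hypothetical helper and may differ from Feedbax's function in details.

```python
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu


def split_like_sketch(key, tree, is_leaf=None):
    """Hypothetical stand-in: one random key per leaf of tree."""
    leaves, treedef = jtu.tree_flatten(tree, is_leaf=is_leaf)
    keys = jr.split(key, len(leaves))
    # Place the i-th key where the i-th leaf was.
    return jtu.tree_unflatten(treedef, list(keys))


tree = {"w": jnp.zeros(3), "b": jnp.zeros(1)}
keys = split_like_sketch(jr.PRNGKey(0), tree)

# Each leaf can now be given its own independent key.
noise = jtu.tree_map(lambda k, x: jr.normal(k, x.shape), keys, tree)
```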

Memory usage

feedbax.tree_array_bytes(tree: PyTree, duplicates: bool = False) -> int

Returns the total bytes of memory over all array leaves of a PyTree.

Parameters:

Name Type Description Default
tree PyTree

The tree with arrays to measure.

required
duplicates bool

If False, then leaves that refer to the same array in memory will only be counted once.

False
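
The effect of the duplicates flag can be illustrated with a small sketch. `array_bytes_sketch` is a hypothetical helper, and it assumes that "the same array in memory" means the same Python object, as judged by `id()`; Feedbax may use a different criterion.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def array_bytes_sketch(tree, duplicates=False):
    """Hypothetical stand-in: total bytes over the array leaves of tree."""
    arrays = [x for x in jtu.tree_leaves(tree) if isinstance(x, jax.Array)]
    if not duplicates:
        # Keep one entry per distinct array object (an assumption, see above).
        arrays = list({id(x): x for x in arrays}.values())
    return sum(x.size * x.dtype.itemsize for x in arrays)


shared = jnp.zeros((10,), dtype=jnp.float32)  # 10 * 4 = 40 bytes
tree = {"a": shared, "b": shared, "c": jnp.zeros((5,), dtype=jnp.float32)}

unique_bytes = array_bytes_sketch(tree)                 # shared counted once
all_bytes = array_bytes_sketch(tree, duplicates=True)   # shared counted twice
```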

feedbax.tree_struct_bytes(tree: PyTree[jax.ShapeDtypeStruct]) -> int

Returns the total bytes of memory implied by a PyTree of ShapeDtypeStructs.
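
The implied size follows from each struct's shape and dtype: the number of elements multiplied by the item size. A minimal sketch, in which `struct_bytes_sketch` is a hypothetical helper and not the Feedbax implementation:

```python
import math

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def struct_bytes_sketch(tree):
    """Hypothetical stand-in: bytes implied by a PyTree of ShapeDtypeStructs."""
    structs = jtu.tree_leaves(tree)
    return sum(
        math.prod(s.shape) * jnp.dtype(s.dtype).itemsize for s in structs
    )


# No arrays are allocated here; only shapes and dtypes are described.
structs = {
    "w": jax.ShapeDtypeStruct((4, 3), jnp.float32),  # 12 * 4 = 48 bytes
    "b": jax.ShapeDtypeStruct((4,), jnp.float32),    # 4 * 4 = 16 bytes
}
total = struct_bytes_sketch(structs)
```

This kind of estimate is useful for checking the memory footprint of a computation's output before running it, for example on the result of `jax.eval_shape`.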