PyTree operations
Feedbax collects a few convenient functions which operate on PyTrees, and which are not provided by its dependencies.
Indexing and assignment
feedbax.tree_take(tree: PyTree[Any, T], indices: ArrayLike, axis: int = 0, **kwargs: Any) -> PyTree[Any, T]
Indexes elements out of each array leaf of a PyTree.
Any non-array leaves are returned unchanged.
This function inherits the default indexing behaviour of JAX. If out-of-bounds indices are provided, no error will be raised.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree[Any, T] | Any PyTree whose array leaves are equivalently indexable, according to the other arguments to this function. For example, axis=0 suits a PyTree whose array leaves share a first (e.g. batch) dimension, from which indices selects a subset. | required |
indices | ArrayLike | The indices of the values to take from each array leaf. | required |
axis | int | The axis of the array leaves over which to take their values. | 0 |
kwargs | Any | Additional keyword arguments to jax.numpy.take. | {} |
Returns:
Type | Description |
---|---|
PyTree[Any, T] | A PyTree with the same structure as tree, whose array leaves have been indexed according to indices and axis. |
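The behaviour described above can be sketched in a few lines. This is a minimal illustration, not feedbax's actual implementation. It assumes that "array leaf" means a JAX or NumPy array, and that indexing is delegated to jnp.take. The name tree_take_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map


def tree_take_sketch(tree, indices, axis=0, **kwargs):
    """Hypothetical re-implementation of tree_take, for illustration only."""

    def take_leaf(leaf):
        # Assumption: only JAX/NumPy array leaves are indexed.
        if isinstance(leaf, (jax.Array, np.ndarray)):
            return jnp.take(leaf, indices, axis=axis, **kwargs)
        # Non-array leaves (here, a string) are returned unchanged.
        return leaf

    return tree_map(take_leaf, tree)


# A batch of 3 examples, each of shape (2,), plus a non-array leaf.
batch = {"x": jnp.arange(6).reshape(3, 2), "name": "batch"}
subset = tree_take_sketch(batch, jnp.array([0, 2]), axis=0)
```

Here subset["x"] holds the first and third rows of batch["x"], and subset["name"] is still the string "batch".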
feedbax.tree_set(tree: PyTree[Any | Shaped[Array, 'batch *?dims'], T], items: PyTree[Any | Shaped[Array, '*?dims'], T], idx: int) -> PyTree[Any | Shaped[Array, 'batch *?dims'], T]
Perform an out-of-place update of each array leaf of a PyTree.
Non-array leaves are simply replaced by their matching leaves in items.
For example, if tree is a PyTree of states over time, whose first dimension is the time step, and items is a PyTree of states for a single time step, this function can be used to insert the latter into the former at a given time index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree[Any \| Shaped[Array, 'batch *?dims'], T] | Any PyTree whose array leaves share a first dimension of the same length, for example a batch dimension. | required |
items | PyTree[Any \| Shaped[Array, '*?dims'], T] | Any PyTree with the same structure as tree, whose array leaves have the shapes of the corresponding leaves of tree without their first dimension. | required |
idx | int | The index along the first dimension of the array leaves of tree at which to insert the array leaves of items. | required |
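Returns:
Type | Description |
---|---|
PyTree[Any \| Shaped[Array, 'batch *?dims'], T] | A PyTree with the same structure as tree, whose array leaves hold the values of items at index idx, and whose non-array leaves are those of items. |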
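The time-step example above can be sketched with JAX's functional .at[...].set(...) update, which returns a new array instead of mutating the old one. This is a minimal illustration under the same assumption as before: "array leaf" means a JAX or NumPy array. The name tree_set_sketch is hypothetical and is not feedbax's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map


def tree_set_sketch(tree, items, idx):
    """Hypothetical re-implementation of tree_set, for illustration only."""

    def set_leaf(leaf, item):
        if isinstance(leaf, (jax.Array, np.ndarray)):
            # Out-of-place update: a new array is returned; `leaf` is unchanged.
            return jnp.asarray(leaf).at[idx].set(item)
        # Non-array leaves are replaced by the matching leaf of `items`.
        return item

    return tree_map(set_leaf, tree, items)


# A history of 4 time steps of a 2-D state, plus a non-array leaf.
history = {"pos": jnp.zeros((4, 2)), "label": "empty"}
step = {"pos": jnp.array([1.0, 2.0]), "label": "filled"}
updated = tree_set_sketch(history, step, 2)
```

Because the update is out-of-place, history is left untouched and must be rebound (e.g. inside a scan or loop) to keep the new values.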
Mapping, unzipping, and stacking
feedbax.get_ensemble(func: Callable[..., PyTree[Any, S]], *args: Any, n_ensemble: int, key: PRNGKeyArray, **kwargs: Any) -> PyTree[Any, S]
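Vmap a function over a set of random keys. The key is split into n_ensemble keys, so each array leaf of the returned PyTree gains a leading ensemble dimension of size n_ensemble.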
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable[..., PyTree[Any, S]] | A function that returns a PyTree, and whose final keyword argument is key. | required |
*args | Any | The positional arguments to func. | () |
n_ensemble | int | The number of keys to split; i.e. the size of the batch dimension in the array leaves of the returned PyTree. | required |
key | PRNGKeyArray | The key to split to perform the vmap. | required |
**kwargs | Any | The keyword arguments to func. | {} |
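A minimal sketch of the idea: split the key, then vmap only over the keys. The other arguments are closed over, so they are shared by every ensemble member. The names get_ensemble_sketch and init_params are hypothetical. This is not feedbax's implementation, and it assumes func is traceable by jax.vmap.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def get_ensemble_sketch(func, *args, n_ensemble, key, **kwargs):
    """Hypothetical re-implementation of get_ensemble, for illustration only."""
    keys = jr.split(key, n_ensemble)
    # Only the keys are batched; *args and **kwargs are closed over,
    # so every ensemble member receives the same values for them.
    return jax.vmap(lambda k: func(*args, **kwargs, key=k))(keys)


def init_params(n_in, n_out, *, key):
    # Hypothetical example: a random weight matrix and a zero bias.
    return {"w": jr.normal(key, (n_out, n_in)), "b": jnp.zeros(n_out)}


params = get_ensemble_sketch(init_params, 3, 5, n_ensemble=4, key=jr.PRNGKey(0))
```

Each array leaf of params gains a leading dimension of size 4. For example, params["w"] has shape (4, 5, 3), with a different random matrix per ensemble member.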
feedbax.tree_unzip(tree: PyTree[Tuple[Any, ...], T]) -> Tuple[PyTree[Any, T], ...]
Unzips a PyTree of tuples into a tuple of PyTrees.
Note
Something similar could be done with jax.tree_util.tree_transpose, but outer_treedef would need to be specified.
This version has zip-like behaviour, in that: (1) the input tree should be flattenable to tuples, when tuples are treated as leaves; and (2) the shortest of those tuples determines the length of the output.
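The zip-like behaviour described in the note can be sketched by flattening with tuples treated as leaves, transposing with zip, and unflattening each group. This is a minimal illustration, not feedbax's implementation. The name tree_unzip_sketch is hypothetical.

```python
import jax.tree_util as jtu


def tree_unzip_sketch(tree):
    """Hypothetical re-implementation of tree_unzip, for illustration only."""
    # Treat tuples as leaves, so the structure "above" the tuples is kept.
    leaves, treedef = jtu.tree_flatten(tree, is_leaf=lambda x: isinstance(x, tuple))
    # zip(*leaves) transposes; like zip, it stops at the shortest tuple.
    return tuple(jtu.tree_unflatten(treedef, list(group)) for group in zip(*leaves))


pairs = {"a": (1, "x"), "b": [(2, "y"), (3, "z", "extra")]}
numbers, letters = tree_unzip_sketch(pairs)
```

The third element "extra" is dropped, because the shortest tuple has length 2.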
feedbax.tree_map_tqdm(f: Callable[..., S], tree: PyTree[Any, T], *rest: PyTree[Any, T], labels: Optional[PyTree[str, T]] = None, verbose: bool = False, is_leaf: Optional[Callable[..., bool]] = None) -> PyTree[S, T]
Maps f over tree like tree_map, while displaying a progress bar over the leaves.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable[..., S] | The function to map over the tree. | required |
tree | PyTree[Any, T] | The PyTree to map over. | required |
*rest | PyTree[Any, T] | Additional PyTrees whose matching leaves are passed to f as further arguments, as in tree_map. | () |
labels | Optional[PyTree[str, T]] | A PyTree of labels for the leaves of tree, used to annotate the progress bar. | None |
is_leaf | Optional[Callable[..., bool]] | A function that returns True for nodes of tree that should be treated as leaves. | None |
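The mechanics can be sketched by flattening tree and iterating over its leaves explicitly, so that progress can be reported after each call to f. To keep the example dependency-free, this sketch prints a plain counter instead of drawing a tqdm bar. That is a deliberate substitution, and the name tree_map_progress is hypothetical. The other trees and the labels are flattened only down to the leaves of tree, via PyTreeDef.flatten_up_to.

```python
import jax.tree_util as jtu


def tree_map_progress(f, tree, *rest, labels=None, is_leaf=None):
    """Hypothetical stand-in for tree_map_tqdm that prints a counter
    instead of drawing a tqdm progress bar."""
    leaves, treedef = jtu.tree_flatten(tree, is_leaf=is_leaf)
    # Flatten the other trees (and labels) only down to the leaves of `tree`.
    rest_leaves = [treedef.flatten_up_to(r) for r in rest]
    label_leaves = (
        treedef.flatten_up_to(labels) if labels is not None else [None] * len(leaves)
    )
    results = []
    for i, (leaf, label, *others) in enumerate(zip(leaves, label_leaves, *rest_leaves)):
        # Report progress, with the leaf's label when one is available.
        print(f"[{i + 1}/{len(leaves)}] {label if label is not None else ''}")
        results.append(f(leaf, *others))
    return jtu.tree_unflatten(treedef, results)


data = {"a": 1.0, "b": [2.0, 3.0]}
labels = {"a": "a", "b": ["b_0", "b_1"]}
doubled = tree_map_progress(lambda x: 2 * x, data, labels=labels)
```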
feedbax.tree_call(tree: PyTree[Any, T], *args: Any, exclude: Callable = lambda _: False, is_leaf: Optional[Callable] = None, **kwargs: Any) -> PyTree[Any, T]
Returns a tree of the return values of a PyTree's callable leaves.
Every callable leaf is passed the same *args and **kwargs.
Non-callable leaves, and callable leaves that satisfy exclude, are passed through as-is.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree[Any, T] | Any PyTree. | required |
*args | Any | Positional arguments to pass to each callable leaf. | () |
exclude | Callable | A function that returns True for any callable leaf that should not be called, but passed through as-is. | lambda _: False |
**kwargs | Any | Keyword arguments to pass to each callable leaf. | {} |
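Since plain Python functions are leaves of a JAX PyTree, the behaviour above reduces to a single tree_map. This is a minimal sketch, not feedbax's implementation. The name tree_call_sketch is hypothetical.

```python
from jax.tree_util import tree_map


def tree_call_sketch(tree, *args, exclude=lambda _: False, is_leaf=None, **kwargs):
    """Hypothetical re-implementation of tree_call, for illustration only."""

    def call_leaf(leaf):
        # exclude is only consulted for callables, thanks to short-circuiting.
        if callable(leaf) and not exclude(leaf):
            return leaf(*args, **kwargs)
        # Non-callables, and excluded callables, are passed through as-is.
        return leaf

    return tree_map(call_leaf, tree, is_leaf=is_leaf)


funcs = {"double": lambda x: 2 * x, "square": lambda x: x ** 2, "const": 7}
results = tree_call_sketch(funcs, 3)
```

A callable that is itself a registered PyTree node, such as an Equinox module, would be traversed into rather than called. Passing is_leaf is the way to stop the traversal at such nodes.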
feedbax.tree_stack(trees: Sequence[PyTree[Any, T]], axis: int = 0) -> PyTree[Any, T]
Returns a PyTree whose array leaves stack those of the PyTrees in trees.
Example
import jax.numpy as jnp
from feedbax import tree_stack
a = [jnp.array([1, 2]), jnp.array([3, 4])]
b = [jnp.array([5, 6]), jnp.array([7, 8])]
tree_stack([a, b], axis=0)
# [jnp.array([[1, 2], [5, 6]]), jnp.array([[3, 4], [7, 8]])]
Derived from this GitHub gist.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
trees | Sequence[PyTree[Any, T]] | A sequence of PyTrees with the same structure, and whose array leaves have the same shape. | required |
axis | int | The axis along which to stack the array leaves. | 0 |
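One common way to implement this is to map jnp.stack over all the trees at once, since tree_map passes the matching leaf from every tree to the mapped function. This is a minimal sketch, not necessarily how feedbax or the gist it derives from does it. The name tree_stack_sketch is hypothetical, and the sketch assumes every leaf is an array.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def tree_stack_sketch(trees, axis=0):
    """Hypothetical re-implementation of tree_stack, for illustration only."""
    # Each call receives the matching leaf from every tree in `trees`,
    # and stacks them into one new array along `axis`.
    return tree_map(lambda *leaves: jnp.stack(leaves, axis=axis), *trees)


a = [jnp.array([1, 2]), jnp.array([3, 4])]
b = [jnp.array([5, 6]), jnp.array([7, 8])]
stacked0 = tree_stack_sketch([a, b], axis=0)
stacked1 = tree_stack_sketch([a, b], axis=1)
```

With axis=1, the new dimension is inserted second, so stacked1[0] pairs up elements rather than whole rows.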
Leaf labels
feedbax.tree_labels(tree: PyTree[Any, T], join_with: str = '_', is_leaf: Optional[Callable[..., bool]] = None) -> PyTree[str, T]
Return a PyTree of labels based on each leaf's key path.
When tree is a flat dict:
tree_labels(tree) == {k: str(k) for k in tree.keys()}
When tree is a flat list:
tree_labels(tree) == [str(i) for i in range(len(tree))]
Verbose tree_map
This function is useful for creating descriptive labels when using tree_map to apply an expensive operation to a PyTree.
from jax.tree_util import tree_map
from feedbax import tree_labels
def expensive_op(x):
    # Something time-consuming
    ...
def verbose_expensive_op(leaf, label):
    print(f"Processing leaf: {label}")
    return expensive_op(leaf)
# `tree` is any PyTree to be processed
result = tree_map(
    verbose_expensive_op,
    tree,
    tree_labels(tree),
)
A similar use case combines this function with tree_map_tqdm to label a progress bar:
from feedbax import tree_labels, tree_map_tqdm
result = tree_map_tqdm(
    expensive_op,
    tree,
    labels=tree_labels(tree),
)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree[Any, T] | The PyTree for which to generate labels. | required |
join_with | str | The string with which to join a leaf's path keys, to form its label. | '_' |
is_leaf | Optional[Callable[..., bool]] | An optional function that returns a boolean, which determines whether each node in tree should be treated as a leaf. | None |
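Labels like those described above can be built from JAX's key paths with jax.tree_util.tree_map_with_path, by converting each path entry to a string and joining the entries. This is a minimal sketch. The exact string formatting used by feedbax may differ, and the names tree_labels_sketch and _key_to_str are hypothetical.

```python
import jax.tree_util as jtu


def _key_to_str(k):
    # Convert one JAX key-path entry to a plain string.
    if isinstance(k, jtu.DictKey):
        return str(k.key)
    if isinstance(k, jtu.SequenceKey):
        return str(k.idx)
    if isinstance(k, jtu.GetAttrKey):
        return k.name
    if isinstance(k, jtu.FlattenedIndexKey):
        return str(k.key)
    return str(k)


def tree_labels_sketch(tree, join_with="_", is_leaf=None):
    """Hypothetical re-implementation of tree_labels, for illustration only."""
    return jtu.tree_map_with_path(
        lambda path, _: join_with.join(_key_to_str(k) for k in path),
        tree,
        is_leaf=is_leaf,
    )


labels = tree_labels_sketch({"a": 1, "b": [2, 3]})
```

This reproduces the flat-dict and flat-list cases stated above, and nested paths are joined with join_with.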
Random keys
feedbax.random_split_like_tree(key: PRNGKeyArray, tree: PyTree[Any, T], is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTree[PRNGKeyArray | None, T]
Splits key into one random key per leaf of tree, and returns the keys arranged as the leaves of a PyTree with the same structure as tree.
Derived from this comment on a discussion in the JAX GitHub repository.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | The random key from which to split the tree of random keys. | required |
tree | PyTree[Any, T] | Any PyTree. | required |
is_leaf | Optional[Callable[[Any], bool]] | An optional function that decides whether each node in tree should be treated as a leaf. | None |
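The core of this pattern is to split once into as many keys as the tree has leaves, then unflatten the keys into the tree's structure. Here is a minimal sketch of that pattern, not feedbax's exact code. The name random_split_like_tree_sketch is hypothetical, and it assumes legacy uint32 keys from jr.PRNGKey.

```python
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu


def random_split_like_tree_sketch(key, tree, is_leaf=None):
    """Hypothetical re-implementation of random_split_like_tree."""
    treedef = jtu.tree_structure(tree, is_leaf=is_leaf)
    # One fresh key per leaf of `tree`.
    keys = jr.split(key, treedef.num_leaves)
    return jtu.tree_unflatten(treedef, list(keys))


params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
leaf_keys = random_split_like_tree_sketch(jr.PRNGKey(0), params)
# Use the matching key to perturb each leaf independently.
noisy = jtu.tree_map(lambda p, k: p + jr.normal(k, p.shape), params, leaf_keys)
```

Pairing each leaf with its own key in a single tree_map avoids reusing one key across leaves.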
Memory usage
feedbax.tree_array_bytes(tree: PyTree, duplicates: bool = False) -> int
Returns the total bytes of memory over all array leaves of a PyTree.
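Parameters:
Name | Type | Description | Default |
---|---|---|---|
tree | PyTree | The tree with arrays to measure. | required |
duplicates | bool | If True, an array that appears at more than one leaf is counted each time it appears; if False, each distinct array is counted only once. | False |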
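A minimal sketch of the idea: sum the nbytes of every array leaf. It assumes that duplicates=False means "count an array object that appears at several leaves only once", and it deduplicates by object identity, not by value. That interpretation is an assumption, and the name tree_array_bytes_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np


def tree_array_bytes_sketch(tree, duplicates=False):
    """Hypothetical re-implementation of tree_array_bytes.

    Assumption: duplicates=False counts each distinct array object once.
    """
    arrays = [
        leaf for leaf in jtu.tree_leaves(tree)
        if isinstance(leaf, (jax.Array, np.ndarray))
    ]
    if not duplicates:
        # Deduplicate by object identity (not by value).
        arrays = list({id(x): x for x in arrays}.values())
    return sum(x.nbytes for x in arrays)


x = jnp.zeros((10, 10), dtype=jnp.float32)  # 100 * 4 = 400 bytes
tree = {"a": x, "b": x, "c": jnp.zeros(5, dtype=jnp.int32)}  # c: 5 * 4 = 20 bytes
```

With these inputs, the shared array x is counted once by default (420 bytes in total) and twice with duplicates=True (820 bytes).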
feedbax.tree_struct_bytes(tree: PyTree[jax.ShapeDtypeStruct]) -> int
Returns the total bytes of memory implied by a PyTree of jax.ShapeDtypeStruct objects.
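The bytes implied by a ShapeDtypeStruct are the product of its shape times the item size of its dtype. This makes it possible to estimate memory before anything is allocated, for example on the output of jax.eval_shape. This is a minimal sketch, not feedbax's implementation. The names tree_struct_bytes_sketch and make_state are hypothetical.

```python
import math

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def tree_struct_bytes_sketch(tree):
    """Hypothetical re-implementation of tree_struct_bytes."""
    # Each ShapeDtypeStruct is a PyTree leaf: elements * bytes per element.
    return sum(
        math.prod(s.shape) * jnp.dtype(s.dtype).itemsize
        for s in jtu.tree_leaves(tree)
    )


def make_state():
    # Hypothetical state: a float32 (100, 2) array and an int32 scalar.
    return {"pos": jnp.zeros((100, 2), jnp.float32), "count": jnp.zeros((), jnp.int32)}


# eval_shape traces make_state abstractly, without allocating its arrays.
structs = jax.eval_shape(make_state)
n_bytes = tree_struct_bytes_sketch(structs)
```

Here n_bytes is 100 * 2 * 4 + 1 * 4 = 804. A scalar's empty shape contributes math.prod(()) == 1 element.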