xlstm_jax.utils.pytree_utils

Attributes

LOGGER

Classes

RecursionLimit

Functions

pytree_diff(tree1, tree2)

Computes the difference between two PyTrees.

pytree_key_path_to_str(path[, separator])

Converts a JAX key path to a string.

flatten_pytree(pytree[, separator, is_leaf])

Flattens a PyTree into a dict.

flatten_dict(d[, separator])

Flattens a nested dictionary.

get_shape_dtype_pytree(x)

Converts a PyTree of jax.Array objects to a PyTree of jax.ShapeDtypeStruct objects.

delete_arrays_in_pytree(x)

Deletes all jax.Array objects in a PyTree and frees them from device memory.

Module Contents

xlstm_jax.utils.pytree_utils.LOGGER
class xlstm_jax.utils.pytree_utils.RecursionLimit
xlstm_jax.utils.pytree_utils.pytree_diff(tree1, tree2)

Computes the difference between two PyTrees.

Parameters:
  • tree1 (xlstm_jax.common_types.PyTree) – First PyTree.

  • tree2 (xlstm_jax.common_types.PyTree) – Second PyTree.

Returns:

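A PyTree of the same structure containing only the differing leaves, each given as a tuple of the leaf from tree1 and the leaf from tree2. Structural mismatches are reported under keys such as 'length_mismatch' and 'shape_mismatch' (see the examples below). Returns None if no differences are found.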

Return type:

xlstm_jax.common_types.PyTree

>>> pytree_diff({"a": 1}, {"a": 2})
{'a': (1, 2)}
>>> pytree_diff({"a": 1}, {"a": 1})
>>> pytree_diff([1, 2, 3], [1, 2])
{'length_mismatch': (3, 2)}
>>> import numpy as np
>>> pytree_diff(np.array([1, 2, 3]), np.array([1, 2]))
{'shape_mismatch': ((3,), (2,))}
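The examples above suggest a recursive comparison. The following is a minimal sketch of one, not the library's implementation: pytree_diff_sketch is a hypothetical name, it only handles nested dicts, lists/tuples, NumPy arrays and plain scalars, and two behaviours are assumptions because the documentation does not specify them: equal-length sequences produce a list with None at matching positions, and keys present in only one dict are compared against None.

```python
import numpy as np


def pytree_diff_sketch(tree1, tree2):
    """Hypothetical re-implementation of pytree_diff (dicts, lists/tuples, NumPy arrays, scalars)."""
    # Nested dictionaries: recurse per key; a key missing on one side is compared against None.
    if isinstance(tree1, dict) and isinstance(tree2, dict):
        keys = list(tree1) + [k for k in tree2 if k not in tree1]
        diff = {}
        for key in keys:
            sub = pytree_diff_sketch(tree1.get(key), tree2.get(key))
            if sub is not None:
                diff[key] = sub
        return diff or None  # an empty diff means "no differences"
    # Sequences: report a length mismatch, otherwise compare element by element.
    if isinstance(tree1, (list, tuple)) and isinstance(tree2, (list, tuple)):
        if len(tree1) != len(tree2):
            return {"length_mismatch": (len(tree1), len(tree2))}
        diffs = [pytree_diff_sketch(a, b) for a, b in zip(tree1, tree2)]
        return diffs if any(d is not None for d in diffs) else None
    # NumPy arrays: report a shape mismatch, otherwise compare values.
    if isinstance(tree1, np.ndarray) or isinstance(tree2, np.ndarray):
        arr1, arr2 = np.asarray(tree1), np.asarray(tree2)
        if arr1.shape != arr2.shape:
            return {"shape_mismatch": (arr1.shape, arr2.shape)}
        return None if np.array_equal(arr1, arr2) else (tree1, tree2)
    # Any other leaf: plain equality.
    return None if tree1 == tree2 else (tree1, tree2)


print(pytree_diff_sketch({"a": 1}, {"a": 2}))                        # {'a': (1, 2)}
print(pytree_diff_sketch([1, 2, 3], [1, 2]))                         # {'length_mismatch': (3, 2)}
print(pytree_diff_sketch(np.array([1, 2, 3]), np.array([1, 2])))     # {'shape_mismatch': ((3,), (2,))}
```

Under these assumptions, pytree_diff_sketch([1, 2, 3], [1, 5, 3]) returns [None, (2, 5), None].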
xlstm_jax.utils.pytree_utils.pytree_key_path_to_str(path, separator='.')

Converts a JAX key path to a string.

An adjusted version of jax.tree_util.keystr that supports custom separators and produces easier-to-read output.

Parameters:
  • path (jax.tree_util.KeyPath) – Key path, for example as returned by jax.tree_util.tree_flatten_with_path.

  • separator (str) – Separator for the keys.

Returns:

The key path as a string, with the individual keys joined by separator.

Return type:

str
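To illustrate the difference from keystr, here is a minimal sketch, assuming the function maps each JAX key entry to its bare name and joins the names with separator; key_path_to_str_sketch is a hypothetical name. jax.tree_util.tree_flatten_with_path, jax.tree_util.keystr and the key classes DictKey, SequenceKey, GetAttrKey and FlattenedIndexKey are part of JAX's public tree utilities.

```python
import jax


def key_path_to_str_sketch(path, separator="."):
    """Hypothetical stand-in for pytree_key_path_to_str: join readable key names."""
    parts = []
    for key in path:
        if isinstance(key, jax.tree_util.DictKey):
            parts.append(str(key.key))  # dict key, e.g. 'params'
        elif isinstance(key, jax.tree_util.SequenceKey):
            parts.append(str(key.idx))  # list/tuple index, e.g. 0
        elif isinstance(key, jax.tree_util.GetAttrKey):
            parts.append(key.name)  # attribute name, e.g. a namedtuple field
        elif isinstance(key, jax.tree_util.FlattenedIndexKey):
            parts.append(str(key.key))  # positional index of a custom node's child
        else:
            parts.append(str(key))  # fallback for any other key type
    return separator.join(parts)


tree = {"params": {"dense": [1, 2]}}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
path, leaf = leaves_with_paths[0]
print(jax.tree_util.keystr(path))         # JAX's bracketed format, for comparison
print(key_path_to_str_sketch(path))       # params.dense.0
print(key_path_to_str_sketch(path, "/"))  # params/dense/0
```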

xlstm_jax.utils.pytree_utils.flatten_pytree(pytree, separator='.', is_leaf=None)

Flattens a PyTree into a dict.

Supports PyTrees with nested dictionaries, lists, tuples, and other PyTree node types. Each key is built by joining the keys along the path to the leaf with separator; for sequences, the element index is used as the key (see the examples below).

Parameters:
  • pytree (xlstm_jax.common_types.PyTree) – PyTree to be flattened.

  • separator (str) – Separator for the keys.

  • is_leaf (collections.abc.Callable[[Any], bool] | None) – Function that determines if a node is a leaf. If None, uses default PyTree leaf detection.

Returns:

A flat dict mapping each key string to its leaf. Raises a ValueError if two leaves map to the same key.

Return type:

dict

>>> flatten_pytree({"a": 1, "b": {"c": 2}})
{'a': 1, 'b.c': 2}
>>> flatten_pytree({"a": 1, "b": (2, 3, 4)}, separator="/")
{'a': 1, 'b/0': 2, 'b/1': 3, 'b/2': 4}
>>> flatten_pytree(("a", "b", "c"))
{'0': 'a', '1': 'b', '2': 'c'}
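A minimal sketch of flatten_pytree built on jax.tree_util.tree_flatten_with_path, assuming the real function derives its keys from JAX key paths in a similar way; flatten_pytree_sketch and _key_to_str are hypothetical names, and the sketch only distinguishes the four standard JAX key entry types. It reproduces the documented examples and the ValueError on duplicate keys.

```python
import jax
from jax.tree_util import DictKey, FlattenedIndexKey, GetAttrKey, SequenceKey


def _key_to_str(key):
    """Readable name of a single JAX key entry (hypothetical helper)."""
    if isinstance(key, (DictKey, FlattenedIndexKey)):
        return str(key.key)
    if isinstance(key, SequenceKey):
        return str(key.idx)
    if isinstance(key, GetAttrKey):
        return key.name
    return str(key)


def flatten_pytree_sketch(pytree, separator=".", is_leaf=None):
    """Hypothetical re-implementation of flatten_pytree on top of tree_flatten_with_path."""
    flat = {}
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(pytree, is_leaf=is_leaf)
    for path, leaf in leaves_with_paths:
        key = separator.join(_key_to_str(k) for k in path)
        if key in flat:
            # Two different paths collapsed to the same string, e.g. {"a.b": 1, "a": {"b": 2}}.
            raise ValueError(f"Duplicate key after flattening: {key!r}")
        flat[key] = leaf
    return flat


print(flatten_pytree_sketch({"a": 1, "b": {"c": 2}}))                  # {'a': 1, 'b.c': 2}
print(flatten_pytree_sketch({"a": 1, "b": (2, 3, 4)}, separator="/"))  # {'a': 1, 'b/0': 2, 'b/1': 3, 'b/2': 4}
# Treating every non-dict node as a leaf keeps tuples intact.
print(flatten_pytree_sketch({"a": {"b": 1}, "c": (2, 3, 4)}, is_leaf=lambda x: not isinstance(x, dict)))
# {'a.b': 1, 'c': (2, 3, 4)}
```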
xlstm_jax.utils.pytree_utils.flatten_dict(d, separator='.')

Flattens a nested dictionary.

In contrast to flatten_pytree, this function is designed specifically for dictionaries and does not flatten sequences: lists and tuples are kept as leaves. It is equivalent to calling flatten_pytree(d, is_leaf=lambda x: not isinstance(x, (dict, FrozenDict))).

Parameters:
  • d (dict | flax.core.FrozenDict) – Dictionary to be flattened.

  • separator (str) – Separator for the keys.

Returns:

Flattened dictionary.

Return type:

dict

>>> flatten_dict({"a": {"b": 1}, "c": (2, 3, 4)})
{'a.b': 1, 'c': (2, 3, 4)}
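Because dictionaries are the only containers it descends into, flatten_dict can be sketched without JAX as a plain recursion. flatten_dict_sketch is a hypothetical name; the sketch treats any collections.abc.Mapping as a nested dictionary (assuming this also covers FrozenDict), keeps insertion order, and, unlike the documented flatten_pytree, does not check for duplicate keys.

```python
from collections.abc import Mapping


def flatten_dict_sketch(d, separator=".", _prefix=None):
    """Hypothetical re-implementation of flatten_dict: recurse into mappings only."""
    flat = {}
    for key, value in d.items():
        full_key = str(key) if _prefix is None else f"{_prefix}{separator}{key}"
        if isinstance(value, Mapping):
            # Nested dictionary: descend and prefix its keys with the current path.
            flat.update(flatten_dict_sketch(value, separator, full_key))
        else:
            # Anything else, including lists and tuples, is kept as a leaf.
            flat[full_key] = value
    return flat


print(flatten_dict_sketch({"a": {"b": 1}, "c": (2, 3, 4)}))  # {'a.b': 1, 'c': (2, 3, 4)}
```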
xlstm_jax.utils.pytree_utils.get_shape_dtype_pytree(x)

Converts a PyTree of jax.Array objects to a PyTree of jax.ShapeDtypeStruct objects.

Leaf nodes of the PyTree that are not jax.Array objects are left unchanged.

Parameters:

x (xlstm_jax.common_types.PyTree) – PyTree of jax.Array objects.

Returns:

PyTree with the same structure, in which every jax.Array leaf is replaced by a jax.ShapeDtypeStruct.

Return type:

xlstm_jax.common_types.PyTree
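A minimal sketch of the conversion with jax.tree_util.tree_map, assuming the real function does essentially the same; get_shape_dtype_pytree_sketch is a hypothetical name, and it copies only shape and dtype (whether the library version also keeps sharding information is not stated here). The resulting abstract PyTree is the kind of input jax.eval_shape accepts, which lets you infer output shapes without allocating device memory.

```python
import jax
import jax.numpy as jnp


def get_shape_dtype_pytree_sketch(x):
    """Hypothetical re-implementation: swap every jax.Array leaf for its abstract shape/dtype."""
    return jax.tree_util.tree_map(
        lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype) if isinstance(leaf, jax.Array) else leaf,
        x,
    )


params = {"w": jnp.ones((4, 3)), "b": jnp.zeros((3,)), "step": 0}
abstract = get_shape_dtype_pytree_sketch(params)  # "step" is a Python int, so it is left unchanged
print(abstract["w"].shape, abstract["w"].dtype)  # (4, 3) float32
# Abstract evaluation: compute the output shape of w @ b without running the computation.
out = jax.eval_shape(lambda w, b: w @ b, abstract["w"], abstract["b"])
print(out.shape)  # (4,)
```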

xlstm_jax.utils.pytree_utils.delete_arrays_in_pytree(x)

Deletes all jax.Array objects in a PyTree and frees them from device memory.

Leaf nodes of the PyTree that are not jax.Array objects are left unchanged.

Parameters:

x (xlstm_jax.common_types.PyTree) – PyTree of jax.Array objects.

Return type:

None
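A minimal sketch, assuming the function walks the PyTree leaves and calls delete() on each jax.Array; delete_arrays_in_pytree_sketch is a hypothetical name. delete() and is_deleted() are methods of concrete JAX arrays, and any later attempt to use a deleted array raises an error, so callers should drop their remaining references to the PyTree afterwards.

```python
import jax
import jax.numpy as jnp


def delete_arrays_in_pytree_sketch(x):
    """Hypothetical re-implementation: free the device buffer of every jax.Array leaf."""
    for leaf in jax.tree_util.tree_leaves(x):
        if isinstance(leaf, jax.Array):
            leaf.delete()  # releases the device buffer; the array is unusable afterwards


state = {"w": jnp.ones((256, 256)), "name": "layer0"}
w = state["w"]
delete_arrays_in_pytree_sketch(state)
print(w.is_deleted())  # True
print(state["name"])   # layer0 (non-array leaves are untouched)
```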