# Copyright 2023 sepes authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for pytrees."""
from __future__ import annotations
import copy
import functools as ft
import operator as op
from math import ceil, floor, trunc
from typing import Any, Callable, Generic, Hashable, Iterator, Sequence, Tuple, TypeVar
from typing_extensions import ParamSpec
import sepes
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable
# generic type variables used in the signatures of this module
T = TypeVar("T")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
# captures the parameter specification of a wrapped callable (e.g. in `bcmap`)
P = ParamSpec("P")
# any pytree; kept as `Any` since pytrees are structurally, not nominally, typed
PyTree = Any
# NOTE(review): declared as a TypeVar rather than `type(Ellipsis)`; it is not
# used in the code visible here -- confirm whether it is still needed
EllipsisType = TypeVar("EllipsisType")
# one step of a trace: a hashable key entry paired with the type of the node
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
TypeEntry = TypeVar("TypeEntry", bound=type)
TraceEntry = Tuple[KeyEntry, TypeEntry]
# full paths from the root: the sequence of keys, the sequence of node types,
# and the `(keys, types)` pair produced by `tree_type_path_leaves`
KeyPath = Tuple[KeyEntry, ...]
TypePath = Tuple[TypeEntry, ...]
KeyTypePath = Tuple[KeyPath, TypePath]
def tree_hash(*trees: PyTree) -> int:
    """Hash ``trees`` from their flattened leaves and tree structure.

    All leaves must be hashable; otherwise ``hash`` raises ``TypeError``.
    """
    flat_leaves, structure = sepes._src.backend.treelib.flatten(trees)
    # leaves first, structure last -- one flat tuple hashed in a single call
    return hash(tuple(flat_leaves) + (structure,))
def tree_copy(tree: T) -> T:
    """Return a copy of ``tree``.

    The containers are rebuilt by mapping over the tree. Every leaf is also
    passed through ``tree_copy.copy_dispatcher``, which is ``copy.copy`` unless
    a type-specific rule was registered with ``tree_copy.def_type``. Instances
    of explicitly registered types are treated as leaves, so their rule
    receives them whole.
    """
    # applying the dispatcher to the leaves is an extra measure on top of
    # flatten/unflatten, so the leaves are copied as well as the containers
    treelib = sepes._src.backend.treelib
    dispatcher = tree_copy.copy_dispatcher
    leaf_types = tuple(kind for kind in dispatcher.registry if kind is not object)
    return treelib.map(
        dispatcher,
        tree,
        is_leaf=lambda node: isinstance(node, leaf_types),
    )


# Leaves fall back to `copy.copy`; types registered through `def_type` (for
# example immutable jax arrays) can instead be returned without copying.
tree_copy.copy_dispatcher = ft.singledispatch(copy.copy)
tree_copy.def_type = tree_copy.copy_dispatcher.register
def _(x: T) -> T:
    # immutable atom types are returned as-is, skipping `copy.copy`
    return x


tree_copy.def_type(int, _)
tree_copy.def_type(float, _)
tree_copy.def_type(complex, _)
tree_copy.def_type(str, _)
tree_copy.def_type(bytes, _)
def is_array_like(node) -> bool:
    """Duck-type check: ``True`` when ``node`` has both ``shape`` and ``dtype``."""
    return all(hasattr(node, attr) for attr in ("shape", "dtype"))
def _is_leaf_rhs_equal(leaf, rhs):
    """Compare one pair of leaves for ``is_tree_equal``.

    Array-like leaves (those with ``shape`` and ``dtype``) are equal only if the
    other side is array-like too and shapes, dtypes and all elements match.
    Comparing an array-like leaf with a non-array-like one gives ``False`` in
    either argument order. Any other pair is compared with ``==``.

    Returns:
        ``bool(verdict)`` for array pairs when it can be computed. If ``bool()``
        fails (e.g. under ``jit``), the unconverted element-wise verdict is
        returned instead.
    """
    leaf_is_array = is_array_like(leaf)
    rhs_is_array = is_array_like(rhs)
    if leaf_is_array != rhs_is_array:
        # previously only array-vs-non-array returned False; the reverse order
        # fell through to `==`, which may produce an element-wise result
        # instead of a bool and made the outcome depend on argument order
        return False
    if not leaf_is_array:
        return leaf == rhs
    if leaf.shape != rhs.shape:
        return False
    if leaf.dtype != rhs.dtype:
        return False
    try:
        verdict = arraylib.all(leaf == rhs)
    except NotImplementedError:
        verdict = leaf == rhs
    try:
        return bool(verdict)
    except Exception:
        return verdict  # fail under `jit`
def is_tree_equal(*trees: Any) -> bool:
    """Return ``True`` if all pytrees are equal.

    Each tree is compared against the first one. A differing tree structure
    gives ``False``. Otherwise leaves are compared pairwise with
    ``_is_leaf_rhs_equal``, and the verdicts are folded together with ``&``,
    so every leaf pair is evaluated. If a leaf comparison cannot be reduced
    to a Python ``bool`` (e.g. under ``jit``), the result may be a non-``bool``
    verdict.

    Note:
        trees are compared using their leaves and treedefs.
    """
    flatten = sepes._src.backend.treelib.flatten
    reference, *others = trees
    ref_leaves, ref_def = flatten(reference)
    result = True
    for other in others:
        other_leaves, other_def = flatten(other)
        # bail out on a structural mismatch, or once an earlier tree was
        # already found to differ
        if other_def != ref_def or result is False:
            return False
        pair_verdicts = map(_is_leaf_rhs_equal, ref_leaves, other_leaves)
        result = ft.reduce(op.and_, pair_verdicts, result)
    return result
class Static(Generic[T]):
    """Base class whose subclasses are registered with the active tree backend.

    Every subclass is passed to the backend's ``register_static`` at creation
    time (per the original comment, this registers it as an empty pytree node).
    """

    def __init_subclass__(klass, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # resolve the backend on each subclass creation (not once at import)
        # so the backend selected at that moment is the one used
        backend_treelib = sepes._src.backend.treelib
        backend_treelib.register_static(klass)
class partial(ft.partial):
    """``functools.partial`` that treats ``...`` in its stored positional
    arguments as placeholders.

    At call time, each ``...`` in ``self.args`` is replaced, left to right, by
    the next positional call argument. Call arguments left over after all
    placeholders are filled are appended after the stored arguments.
    Call-time keyword arguments override stored keyword arguments.

    Raises:
        TypeError: if fewer positional call arguments are given than there are
            ``...`` placeholders.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        iargs = iter(args)
        filled = []
        for arg in self.args:
            if arg is ...:
                try:
                    arg = next(iargs)
                except StopIteration:
                    # previously this StopIteration escaped a generator
                    # expression and surfaced as an opaque RuntimeError
                    num_slots = sum(stored is ... for stored in self.args)
                    raise TypeError(
                        f"{type(self).__name__} expected at least {num_slots} "
                        f"positional argument(s) to fill the `...` placeholders, "
                        f"got {len(args)}"
                    ) from None
            filled.append(arg)
        return self.func(*filled, *iargs, **{**self.keywords, **kwargs})
def bcmap(
    func: Callable[P, T],
    broadcast_to: int | str | None = None,
    *,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Callable[P, T]:
    """Map a function over pytree leaves with automatic broadcasting for scalar arguments.

    Args:
        func: the function to be mapped over the pytree.
        broadcast_to: Accepts integer for broadcasting to a specific argument
            or string for broadcasting to a specific keyword argument.
            If ``None``, then the function is broadcasted to the first argument
            or the first keyword argument if no positional arguments are provided.
            Defaults to ``None``.
        is_leaf: a predicate function that returns True if the node is a leaf.

    Note:
        The argument selected by ``broadcast_to`` is the reference pytree.
        Every positional or keyword argument whose tree structure matches the
        reference is mapped leaf by leaf. Every other argument is passed
        unchanged to each call of ``func``. The result has the reference's
        tree structure.

    Raises:
        TypeError: if ``broadcast_to`` is ``None`` and the wrapped function is
            called without any arguments.

    Example:
        Transform `numpy` functions to work with pytrees:

        >>> import sepes as sp
        >>> import jax.numpy as jnp
        >>> tree_of_arrays = {"a": jnp.array([1, 2, 3]), "b": jnp.array([4, 5, 6])}
        >>> tree_add = sp.bcmap(jnp.add)
        >>> # both lhs and rhs are pytrees
        >>> print(sp.tree_str(tree_add(tree_of_arrays, tree_of_arrays)))
        dict(a=[2 4 6], b=[ 8 10 12])
        >>> # rhs is a scalar
        >>> print(sp.tree_str(tree_add(tree_of_arrays, 1)))
        dict(a=[2 3 4], b=[5 6 7])
    """
    treelib = sepes._src.backend.treelib

    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        if broadcast_to is not None:
            bdcst_to = broadcast_to
        elif args:
            bdcst_to = 0
        elif kwargs:
            bdcst_to = next(iter(kwargs))
        else:
            raise TypeError(
                "`bcmap` needs at least one argument to infer the reference pytree."
            )

        # a string names a keyword argument, anything else indexes the
        # positional arguments (previously `args["name"]` was attempted
        # whenever positional arguments were present)
        reference = kwargs[bdcst_to] if isinstance(bdcst_to, str) else args[bdcst_to]
        treedef0 = treelib.flatten(reference, is_leaf=is_leaf)[1]

        cargs = []
        ckwargs = {}
        leaves = []
        kwargs_keys: list[str] = []

        for arg in args:
            if treedef0 == treelib.flatten(arg, is_leaf=is_leaf)[1]:
                # `...` is a placeholder filled per leaf by `partial`
                cargs.append(...)
                leaves.append(treedef0.flatten_up_to(arg))
            else:
                cargs.append(arg)

        for key, value in kwargs.items():
            if treedef0 == treelib.flatten(value, is_leaf=is_leaf)[1]:
                ckwargs[key] = ...
                leaves.append(treedef0.flatten_up_to(value))
                kwargs_keys.append(key)
            else:
                ckwargs[key] = value

        # mapped positional leaves come first, mapped keyword leaves after
        split_index = len(leaves) - len(kwargs_keys)
        bfunc = partial(func, *cargs, **ckwargs)
        all_leaves = []
        for values in zip(*leaves):
            pos_values = values[:split_index]
            kw_values = dict(zip(kwargs_keys, values[split_index:]))
            all_leaves.append(bfunc(*pos_values, **kw_values))
        return treelib.unflatten(treedef0, all_leaves)

    return wrapper
def swop(func):
    """Return ``func`` with its two positional arguments swapped.

    ``swop(f)(a, b)`` calls ``f(b, a)``; metadata of ``func`` is copied onto
    the result with ``functools.wraps``.
    """

    def swapped(leaf, rhs):
        return func(rhs, leaf)

    return ft.wraps(func)(swapped)
def leafwise(klass: type[T]) -> type[T]:
    """A class decorator that adds leafwise operators to a class.

    Leafwise operators are operators that are applied to the leaves of a pytree.
    For example leafwise ``__add__`` is equivalent to:

    - ``tree_map(lambda x: x + rhs, tree)`` if ``rhs`` is not an instance of
      the decorated class (e.g. a scalar).
    - ``tree_map(lambda x, y: x + y, tree, rhs)`` if ``rhs`` is an instance of
      the decorated class, i.e. a pytree with the same structure as ``tree``.

    Args:
        klass: The class to be decorated.

    Returns:
        The decorated class.

    Example:
        Use ``numpy`` functions on :class:`TreeClass` classes decorated with :func:`leafwise`

        >>> import sepes as sp
        >>> import jax.numpy as jnp
        >>> @sp.leafwise
        ... @sp.autoinit
        ... class Point(sp.TreeClass):
        ...     x: float = 0.5
        ...     y: float = 1.0
        ...     description: str = "point coordinates"
        >>> # use :func:`tree_mask` to mask the non-inexact part of the tree
        >>> # i.e. mask the string leaf ``description`` so that ``Point`` works
        >>> # with ``jax.numpy`` functions
        >>> co = sp.tree_mask(Point())
        >>> print(sp.bcmap(jnp.where)(co > 0.5, co, 1000))
        Point(x=1000.0, y=1.0, description=#point coordinates)

    Note:
        A method is added only if the class body does not already define a
        method with the same name (checked against ``vars(klass)``). Methods
        inherited from base classes are overridden.

        ================== ============
        Method             Operator
        ================== ============
        ``__abs__``        ``abs``
        ``__add__``        ``+``
        ``__and__``        ``&``
        ``__ceil__``       ``math.ceil``
        ``__divmod__``     ``divmod``
        ``__eq__``         ``==``
        ``__floor__``      ``math.floor``
        ``__floordiv__``   ``//``
        ``__ge__``         ``>=``
        ``__gt__``         ``>``
        ``__invert__``     ``~``
        ``__le__``         ``<=``
        ``__lshift__``     ``<<``
        ``__lt__``         ``<``
        ``__matmul__``     ``@``
        ``__mod__``        ``%``
        ``__mul__``        ``*``
        ``__ne__``         ``!=``
        ``__neg__``        ``-``
        ``__or__``         ``|``
        ``__pos__``        ``+``
        ``__pow__``        ``**``
        ``__round__``      ``round``
        ``__rshift__``     ``>>``
        ``__sub__``        ``-``
        ``__truediv__``    ``/``
        ``__trunc__``      ``math.trunc``
        ``__xor__``        ``^``
        ================== ============

        The reflected variants ``__radd__``, ``__rand__``, ``__rdivmod__``,
        ``__rfloordiv__``, ``__rlshift__``, ``__rmatmul__``, ``__rmod__``,
        ``__rmul__``, ``__ror__``, ``__rpow__``, ``__rrshift__``, ``__rsub__``,
        ``__rtruediv__`` and ``__rxor__`` are added as well.
    """
    treelib = sepes._src.backend.treelib

    def uop(func):
        # unary operator: apply `func` to every leaf
        def wrapper(self):
            return treelib.map(func, self)

        return ft.wraps(func)(wrapper)

    def bop(func):
        # binary operator: combine leaf-by-leaf with another instance of the
        # same class, otherwise broadcast `rhs` to every leaf. `rhs=None` lets
        # `round(tree)` call `__round__()` without `ndigits`.
        def wrapper(leaf, rhs=None):
            if isinstance(rhs, type(leaf)):
                return treelib.map(func, leaf, rhs)
            return treelib.map(lambda x: func(x, rhs), leaf)

        return ft.wraps(func)(wrapper)

    for key, method in (
        ("__abs__", uop(abs)),
        ("__add__", bop(op.add)),
        ("__and__", bop(op.and_)),
        ("__ceil__", uop(ceil)),
        ("__divmod__", bop(divmod)),
        ("__eq__", bop(op.eq)),
        ("__floor__", uop(floor)),
        ("__floordiv__", bop(op.floordiv)),
        ("__ge__", bop(op.ge)),
        ("__gt__", bop(op.gt)),
        ("__invert__", uop(op.invert)),
        ("__le__", bop(op.le)),
        ("__lshift__", bop(op.lshift)),
        ("__lt__", bop(op.lt)),
        ("__matmul__", bop(op.matmul)),
        ("__mod__", bop(op.mod)),
        ("__mul__", bop(op.mul)),
        ("__ne__", bop(op.ne)),
        ("__neg__", uop(op.neg)),
        ("__or__", bop(op.or_)),
        ("__pos__", uop(op.pos)),
        ("__pow__", bop(op.pow)),
        ("__radd__", bop(swop(op.add))),
        ("__rand__", bop(swop(op.and_))),
        ("__rdivmod__", bop(swop(divmod))),
        ("__rfloordiv__", bop(swop(op.floordiv))),
        ("__rlshift__", bop(swop(op.lshift))),
        ("__rmatmul__", bop(swop(op.matmul))),
        ("__rmod__", bop(swop(op.mod))),
        ("__rmul__", bop(swop(op.mul))),
        ("__ror__", bop(swop(op.or_))),
        ("__round__", bop(round)),
        ("__rpow__", bop(swop(op.pow))),
        ("__rrshift__", bop(swop(op.rshift))),
        ("__rshift__", bop(op.rshift)),
        ("__rsub__", bop(swop(op.sub))),
        ("__rtruediv__", bop(swop(op.truediv))),
        ("__rxor__", bop(swop(op.xor))),
        ("__sub__", bop(op.sub)),
        ("__truediv__", bop(op.truediv)),
        ("__trunc__", uop(trunc)),
        ("__xor__", bop(op.xor)),
    ):
        if key not in vars(klass):
            # do not override methods defined in the class body itself;
            # this is similar to the behavior of `dataclasses.dataclass`
            setattr(klass, key, method)

    return klass
def tree_type_path_leaves(
    tree: PyTree,
    *,
    is_leaf: Callable[[Any], bool] | None = None,
    is_path_leaf: Callable[[KeyTypePath], bool] | None = None,
) -> Sequence[tuple[KeyTypePath, Any]]:
    """Flatten ``tree`` one level at a time into ``((keys, types), leaf)`` pairs.

    Mainly used for visualization. Starting from ``((), ())``, each step
    flattens a node by a single level. For every child, the key entries from
    ``treelib.path_flatten`` are appended to ``keys`` and the child's type to
    ``types``.

    A node is emitted whole, without descending further, in any of these cases:

    - ``is_leaf(node)`` is true;
    - ``is_path_leaf((keys, types))`` is true;
    - the node flattens to the same treedef as the scalar ``1``.
    """
    treelib = sepes._src.backend.treelib
    _leaves, scalar_def = treelib.flatten(1)

    def walk(trace: KeyTypePath, node: PyTree):
        if (is_leaf and is_leaf(node)) or (is_path_leaf and is_path_leaf(trace)):
            yield trace, node
            return

        def stop_below_node(child) -> bool:
            # honor the user predicate; otherwise expand only `node` itself,
            # so the flatten below descends exactly one level
            if is_leaf and is_leaf(child):
                return True
            return child is not node

        entries, structure = treelib.path_flatten(node, is_leaf=stop_below_node)
        if structure == scalar_def:
            yield trace, node
            return

        key_entries, type_entries = trace
        for key, child in entries:
            child_trace = ((*key_entries, *key), (*type_entries, type(child)))
            yield from walk(child_trace, child)

    return list(walk(((), ()), tree))
class Node:
    """A node of the tree built by ``construct_tree`` (mainly for visualization).

    Attributes:
        data: a ``(trace_entry, value)`` pair; ``trace_entry`` is the key under
            which this node is stored in its parent's ``children``.
        parent: the parent node, or ``None`` for a root or detached node.
        children: child nodes keyed by their trace entry, in insertion order.
    """

    # `__weakref__` keeps instances weak-referenceable despite `__slots__`
    __slots__ = ["data", "parent", "children", "__weakref__"]

    def __init__(
        self,
        data: tuple[TraceEntry, Any],
        parent: Node | None = None,
    ):
        self.data = data
        self.parent = parent
        self.children: dict[TraceEntry, Node] = {}

    def add_child(self, child: Node) -> None:
        """Store ``child`` under its trace entry and make ``self`` its parent.

        A child whose trace entry is already present is ignored: it is neither
        stored nor re-parented.

        Raises:
            TypeError: if ``child`` is not a ``Node``.
        """
        if not isinstance(child, Node):
            raise TypeError(f"`child` must be a `Node`, got {type(child)}")
        entry, _ = child.data
        if entry in self.children:
            return
        child.parent = self
        self.children[entry] = child

    def __contains__(self, key: TraceEntry) -> bool:
        return key in self.children

    def __iter__(self) -> Iterator[Node]:
        # yields the child nodes, not their keys
        return iter(self.children.values())

    def __repr__(self) -> str:
        return f"Node(data={self.data})"
def is_path_leaf_depth_factory(depth: int | float | None):
    """Build an ``is_path_leaf`` predicate that stops tracing at ``depth``.

    Args:
        depth: the key-path length at which tracing stops. ``None`` disables
            the limit, so the predicate always returns ``False``;
            ``float("inf")`` has the same effect.

    Returns:
        A predicate that takes a ``(keys, types)`` trace and returns ``True``
        once ``len(keys)`` is at least ``depth``.
    """

    def is_path_leaf(trace) -> bool:
        keys, _ = trace
        # stop tracing once the requested depth is reached
        return False if depth is None else (depth <= len(keys))

    return is_path_leaf
def construct_tree(
    tree: PyTree,
    is_leaf: Callable[[Any], bool] | None = None,
    is_path_leaf: Callable[[KeyTypePath], bool] | None = None,
) -> Node:
    """Build a ``Node`` hierarchy mirroring ``tree`` (used for visualization).

    The root holds ``((None, type(tree)), tree)``. For each ``(trace, leaf)``
    pair produced by ``tree_type_path_leaves``, the ``(key, type)`` entries
    of the trace are walked from the root. Existing child nodes are reused and
    missing ones are created. A newly created node stores ``leaf`` only when
    it is the last entry of the trace; other new nodes store ``None``.
    """
    path_leaf_pairs = tree_type_path_leaves(
        tree,
        is_leaf=is_leaf,
        is_path_leaf=is_path_leaf,
    )
    root = Node(data=((None, type(tree)), tree))

    for (keys, types), leaf in path_leaf_pairs:
        last_index = len(keys) - 1
        node = root
        for index, entry in enumerate(zip(keys, types)):
            if entry not in node:
                # first time this path prefix is seen: create its node
                value = leaf if index == last_index else None
                new_node = Node(data=(entry, value))
                node.add_child(new_node)
                node = new_node
            else:
                # prefix shared with an earlier trace
                node = node.children[entry]
    return root
def value_and_tree(func: Callable[..., T], argnums: int | Sequence[int] = 0):
    """Call a function on copied input arguments and return the value and the tree(s).

    Input arguments are copied before calling the function, and the
    (copied, possibly mutated) positional arguments selected by ``argnums``
    are returned alongside the function's output.

    Args:
        func: A function.
        argnums: The index (or indices) of the positional argument(s) to
            return. For a single ``int`` the tree is returned as is. For a
            sequence, the trees are returned as a tuple in the order given by
            ``argnums``. Negative indices count from the end of the positional
            arguments.

    Returns:
        A function that returns ``(value, tree)``.

    Raises:
        IndexError: at call time, if an index in ``argnums`` is out of range
            for the positional arguments.

    Example:
        Usage with mutable types:

        >>> import sepes as sp
        >>> mutable_tree = [1, 2, 3]
        >>> def mutating_func(tree):
        ...     tree[0] += 100
        ...     return tree
        >>> new_tree = mutating_func(mutable_tree)
        >>> assert new_tree is mutable_tree
        >>> # now with `value_and_tree` the function does not mutate the tree
        >>> new_tree, _ = sp.value_and_tree(mutating_func)(mutable_tree)
        >>> assert new_tree is not mutable_tree

    Example:
        Usage with immutable types (:class:`.TreeClass`) with support for in-place
        mutation via custom behavior registration using :func:`.value_and_tree.def_mutator`
        and :func:`.value_and_tree.def_immutator`:

        >>> import sepes as sp
        >>> class Counter(sp.TreeClass):
        ...     def __init__(self, count: int):
        ...         self.count = count
        ...     def increment(self, value):
        ...         self.count += value
        ...         return self.count
        >>> counter = Counter(0)
        >>> counter.increment(1)  # doctest: +SKIP
        AttributeError: Cannot set attribute value=1 to `key='count'` on an immutable instance of `Counter`.
        >>> sp.value_and_tree(lambda counter: counter.increment(1))(counter)
        (1, Counter(count=1))

    Note:
        Use this function on functions that:

        - Mutate input arguments of mutable types (e.g. lists, dicts, etc.).
        - Mutate input arguments of immutable types that do not support in-place
          mutation and need special handling that can be registered (e.g. :class:`.TreeClass`)
          using :func:`.value_and_tree.def_mutator` and :func:`.value_and_tree.def_immutator`.

    Note:
        The default behavior of :func:`value_and_tree` is to copy the input
        arguments and then call the function on the copy. However, if the function
        mutates an input argument that does not support in-place mutation,
        the call will fail. For this case, :func:`value_and_tree` lets you
        register a custom function that makes the copied argument mutable
        before the call, and a custom function that restores it to its
        immutable state after the call. The following example shows how to
        register such functions for a simple class that allows in-place
        mutation only when its ``immutable`` flag is set to ``False``.

        >>> import jax
        >>> from jax.util import unzip2
        >>> import sepes as sp
        >>> @jax.tree_util.register_pytree_node_class
        ... class MyNode:
        ...     def __init__(self):
        ...         self.counter = 0
        ...         self.immutable = True
        ...     def tree_flatten(self):
        ...         keys, values = unzip2(vars(self).items())
        ...         return tuple(values), tuple(keys)
        ...     @classmethod
        ...     def tree_unflatten(cls, keys, values):
        ...         self = object.__new__(cls)
        ...         vars(self).update(dict(zip(keys, values)))
        ...         return self
        ...     def __setattr__(self, name, value):
        ...         if getattr(self, "immutable", False) is True:
        ...             raise AttributeError("MyNode is immutable")
        ...         object.__setattr__(self, name, value)
        ...     def __repr__(self):
        ...         params = ", ".join(f"{k}={v}" for k, v in vars(self).items())
        ...         return f"MyNode({params})"
        ...     def increment(self) -> None:
        ...         self.counter += 1
        >>> @sp.value_and_tree.def_mutator(MyNode)
        ... def mutable(node) -> None:
        ...     vars(node)["immutable"] = False
        >>> @sp.value_and_tree.def_immutator(MyNode)
        ... def immutable(node) -> None:
        ...     vars(node)["immutable"] = True
        >>> node = MyNode()
        >>> sp.value_and_tree(lambda node: node.increment())(node)
        (None, MyNode(counter=1, immutable=True))
    """
    treelib = sepes._src.backend.treelib
    is_int_argnum = isinstance(argnums, int)
    # normalize once (also allows a one-shot iterable for `argnums`)
    argnums = (argnums,) if is_int_argnum else tuple(argnums)

    def mutate_is_leaf(node):
        # used only for its side effect: run the registered mutator
        # (identity by default) on every visited node; returning False
        # never marks a node as a leaf, so traversal continues
        value_and_tree.mutator_dispatcher(node)
        return False

    def immutate_is_leaf(node):
        # same as above, but undoes the mutation via the immutator
        value_and_tree.immutator_dispatcher(node)
        return False

    @ft.wraps(func)
    def stateless_func(*args, **kwargs) -> tuple[T, PyTree | tuple[PyTree, ...]]:
        # copy the incoming inputs
        (args, kwargs) = tree_copy((args, kwargs))
        # and edit the node/record to make it mutable (if there is a rule for it)
        treelib.map(lambda _: _, (args, kwargs), is_leaf=mutate_is_leaf)
        output = func(*args, **kwargs)
        # traverse each node in the tree depth-first manner
        # to undo the mutation (if there is a rule for it)
        treelib.map(lambda _: _, (args, kwargs), is_leaf=immutate_is_leaf)
        # index in `argnums` order; the former `i in argnums` filter returned
        # positional order, ignored negative indices and dropped
        # out-of-range ones silently
        out_args = tuple(args[i] for i in argnums)
        out_args = out_args[0] if is_int_argnum else out_args
        return output, out_args

    return stateless_func


# identity by default; per-type rules are registered with `def_mutator` /
# `def_immutator` and applied before / after the wrapped call
value_and_tree.mutator_dispatcher = ft.singledispatch(lambda node: node)
value_and_tree.immutator_dispatcher = ft.singledispatch(lambda node: node)
value_and_tree.def_mutator = value_and_tree.mutator_dispatcher.register
value_and_tree.def_immutator = value_and_tree.immutator_dispatcher.register
if is_package_avaiable("jax"):
    import jax

    def _(node: T1) -> T1:
        # returned as-is instead of going through `copy.copy`
        return node

    # jax arrays are immutable, so there is no need to copy them
    tree_copy.def_type(jax.Array, _)
    # jitted functions lose their wrapped-function attributes when passed
    # through `__copy__` (possibly a jax bug), so they are not copied either
    tree_copy.def_type(type(jax.jit(lambda x: x)), _)