# Source code for sepes._src.tree_mask

# Copyright 2023 sepes authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to work with non-inexact type tree leaves across function transformations."""

from __future__ import annotations

import functools as ft
import hashlib
from typing import Any, Callable, NamedTuple, TypeVar, Union

import sepes
import sepes._src.backend.arraylib as arraylib
from sepes._src.backend import is_package_avaiable
from sepes._src.tree_pprint import tree_repr, tree_str, tree_summary
from sepes._src.tree_util import Static, is_tree_equal, tree_copy, tree_hash

T = TypeVar("T")
MaskType = Union[T, Callable[[Any], bool]]


def is_nondiff(value: Any) -> bool:
    """Return whether ``value`` is treated as a non-differentiable leaf.

    The answer is looked up by the type of ``value`` through
    ``is_nondiff.type_dispatcher``. Types without a registered rule fall back
    to ``True``. Additional per-type rules are added with
    ``is_nondiff.def_type``.
    """
    dispatch = is_nondiff.type_dispatcher
    return dispatch(value)


# fallback rule: any type without an explicit rule is non-differentiable
is_nondiff.type_dispatcher = ft.singledispatch(lambda _: True)
# expose the dispatcher's ``register`` so per-type rules can be added
is_nondiff.def_type = is_nondiff.type_dispatcher.register


for ndarray in arraylib.ndarrays:

    @is_nondiff.def_type(ndarray)
    def is_nondiff_array(value) -> bool:
        """Return ``False`` when ``arraylib.is_inexact`` accepts ``value``."""
        # only inexact-typed arrays are excluded from the non-diff set
        return not arraylib.is_inexact(value)


@is_nondiff.def_type(complex)
@is_nondiff.def_type(float)
def _(_: float | complex) -> bool:
    # python floats and complex numbers are never flagged as non-differentiable
    return False


class _MaskedError(NamedTuple):
    opname: str

    def __call__(self, *a, **k):
        raise NotImplementedError(
            f"Cannot apply `{self.opname}` operation on a masked object "
            f"{', '.join(map(str, a))} "
            f"{', '.join(k + '=' + str(v) for k, v in k.items())}.\n"
            "Unmask the object first using `tree_unmask`"
        )


class _MaskBase(Static[T]):
    """Frozen wrapper that hides a pytree node from tree flattening.

    The objective of this class is to wrap a pytree node with a custom wrapper
    that yields no leaves when flattened. This is useful to avoid updating
    the node by effectively *hiding it* from function transformations that
    operate on flattened pytrees.

    The wrapped node is stored under ``__wrapped__``. Instances are immutable:
    attribute assignment and deletion raise ``AttributeError``. Arithmetic,
    bitwise, comparison and call operations raise ``NotImplementedError`` with
    a hint to unmask the object first.
    """

    __slots__ = ["__wrapped__"]
    __wrapped__: T

    def __init__(self, node: T) -> None:
        # bypass the frozen ``__setattr__`` below to store the wrapped node
        object.__setattr__(self, "__wrapped__", node)

    def __setattr__(self, _, __) -> None:
        raise AttributeError("Cannot assign to frozen instance.")

    def __delattr__(self, _: str) -> None:
        raise AttributeError("Cannot delete from frozen instance.")

    def __repr__(self) -> str:
        # masked values are displayed with a leading `#`
        return "#" + tree_repr(self.__wrapped__)

    def __str__(self) -> str:
        return "#" + tree_str(self.__wrapped__)

    def __copy__(self) -> _MaskBase[T]:
        # copy the wrapped node and re-wrap it in the same wrapper type
        return type(self)(tree_copy(self.__wrapped__))

    # raise helpful error message when trying to interact with frozen object
    __add__ = __radd__ = __iadd__ = _MaskedError("+")
    __sub__ = __rsub__ = __isub__ = _MaskedError("-")
    __mul__ = __rmul__ = __imul__ = _MaskedError("*")
    __matmul__ = __rmatmul__ = __imatmul__ = _MaskedError("@")
    __truediv__ = __rtruediv__ = __itruediv__ = _MaskedError("/")
    __floordiv__ = __rfloordiv__ = __ifloordiv__ = _MaskedError("//")
    __mod__ = __rmod__ = __imod__ = _MaskedError("%")
    __pow__ = __rpow__ = __ipow__ = _MaskedError("**")
    __lshift__ = __rlshift__ = __ilshift__ = _MaskedError("<<")
    __rshift__ = __rrshift__ = __irshift__ = _MaskedError(">>")
    __and__ = __rand__ = __iand__ = _MaskedError("and")
    # the operator name was previously empty, which produced the unhelpful
    # message "Cannot apply `` operation ..."; name it like `and` / `or`
    __xor__ = __rxor__ = __ixor__ = _MaskedError("xor")
    __or__ = __ror__ = __ior__ = _MaskedError("or")
    __neg__ = __pos__ = __abs__ = __invert__ = _MaskedError("unary")
    __lt__ = __le__ = __gt__ = __ge__ = _MaskedError("comparison")
    __call__ = _MaskedError("__call__")


@tree_summary.def_type(_MaskBase)
def _(node) -> str:
    # summarize the wrapped value, then prefix `#` to mark it as masked
    inner_summary = tree_summary.type_dispatcher(node.__wrapped__)
    return f"#{inner_summary}"


class _MaskedHashable(_MaskBase):
    """Mask wrapper whose hash and equality are delegated to the wrapped tree.

    Hashing uses ``tree_hash`` and equality uses ``is_tree_equal`` on the
    wrapped values. Anything that is not a ``_MaskedHashable`` compares
    unequal.
    """

    def __hash__(self) -> int:
        wrapped = self.__wrapped__
        return tree_hash(wrapped)

    def __eq__(self, rhs: Any) -> bool:
        if isinstance(rhs, _MaskedHashable):
            return is_tree_equal(self.__wrapped__, rhs.__wrapped__)
        return False


class _MaskedArray(_MaskBase):
    """Mask wrapper for arrays with content-based hash and equality.

    The hash is the sha256 digest of the wrapped array's bytes representation,
    read as an integer. Equality requires matching shape, dtype and elements.
    This is useful to select some array to hold without updating in the
    process of training a model.
    """

    def __hash__(self) -> int:
        raw = arraylib.tobytes(self.__wrapped__)
        digest_hex = hashlib.sha256(raw).hexdigest()
        return int(digest_hex, 16)

    def __eq__(self, other) -> bool:
        if not isinstance(other, _MaskedArray):
            return False
        lhs, rhs = self.__wrapped__, other.__wrapped__
        # cheap shape/dtype checks first to avoid an element-wise comparison
        # on large arrays that cannot be equal anyway
        if (
            arraylib.shape(lhs) != arraylib.shape(rhs)
            or arraylib.dtype(lhs) != arraylib.dtype(rhs)
        ):
            return False
        return arraylib.array_equal(lhs, rhs)


def _tree_mask_map(
    tree: T,
    cond: Callable[[Any], bool],
    func: Callable[[Any], Any],
    *,
    is_leaf: Callable[[Any], None] | None = None,
):
    if not isinstance(cond, Callable):
        # a callable that accepts a leaf and returns a boolean
        # but *not* a tree with the same structure as tree with boolean values.
        raise TypeError(
            f"`cond` must be a callable that accepts a leaf and returns a boolean "
            f" Got {cond=} and {tree=}."
        )

    treelib = sepes._src.backend.treelib

    def map_func(x):
        return func(x) if cond(x) else x

    return treelib.map(map_func, tree, is_leaf=is_leaf)


def tree_mask(
    tree: T,
    cond: Callable[[Any], bool] = is_nondiff,
    *,
    is_leaf: Callable[[Any], bool] | None = None,
):
    """Mask the leaves of a pytree that are selected by the ``cond`` callable.

    Masked leaves are wrapped with a wrapper that yields no leaves when
    ``tree_flatten`` is called on it.

    Args:
        tree: A pytree of values.
        cond: A callable that accepts a leaf and returns a boolean to mark the leaf
            for masking. Defaults to masking non-differentiable leaf nodes that
            are not instances of python float, python complex, or inexact
            array types. A boolean pytree is *not* accepted; passing a
            non-callable raises ``TypeError``.
        is_leaf: A callable that accepts a leaf and returns a boolean. If
            provided, it is used to determine if a value is a leaf. For example,
            ``is_leaf=lambda x: isinstance(x, list)`` will treat lists as leaves
            and will not recurse into them.

    Raises:
        TypeError: If ``cond`` is not callable.

    Example:
        >>> import sepes as sp
        >>> import jax
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = sp.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> sp.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        Pass non-differentiable values to ``jax.grad``

        >>> import sepes as sp
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = sp.tree_unmask(tree)
        ...     return tree[0] ** 2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(sp.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)

    Example:
        Define a custom masking wrapper for a specific type.

        >>> import sepes as sp
        >>> import jax
        >>> import dataclasses as dc
        >>> @dc.dataclass
        ... class MyInt:
        ...     value: int
        >>> @dc.dataclass
        ... class MaskedInt:
        ...     value: MyInt
        >>> # define a rule of how to mask an integer
        >>> @sp.tree_mask.def_type(MyInt)
        ... def mask_int(value):
        ...     return MaskedInt(value)
        >>> # define a rule how to unmask the wrapper
        >>> @sp.tree_unmask.def_type(MaskedInt)
        ... def unmask_int(value):
        ...     return value.value
        >>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
        >>> masked_tree = sp.tree_mask(tree, cond=lambda _: True)
        >>> masked_tree
        [MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]
        >>> sp.tree_unmask(masked_tree)
        [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
    """
    return _tree_mask_map(
        tree,
        cond=cond,
        func=tree_mask.dispatcher,
        is_leaf=is_leaf,
    )


# types without a dedicated masking rule are wrapped in ``_MaskedHashable``
tree_mask.dispatcher = ft.singledispatch(_MaskedHashable)
# expose the dispatcher's ``register`` so per-type masking rules can be added
tree_mask.def_type = tree_mask.dispatcher.register
def tree_unmask(tree: T, cond: Callable[[Any], bool] = lambda _: True):
    """Undo the masking of tree leaves according to ``cond``.

    Defaults to unmasking all leaves.

    Args:
        tree: A pytree of values.
        cond: A callable that accepts a leaf and returns a boolean to mark the
            leaf to be unmasked. Defaults to always unmask.

    Example:
        >>> import sepes as sp
        >>> import jax
        >>> tree = [1, 2, {"a": 3, "b": 4.}]
        >>> # mask all non-differentiable nodes by default
        >>> masked_tree = sp.tree_mask(tree)
        >>> masked_tree
        [#1, #2, {'a': #3, 'b': 4.0}]
        >>> jax.tree_util.tree_leaves(masked_tree)
        [4.0]
        >>> sp.tree_unmask(masked_tree)
        [1, 2, {'a': 3, 'b': 4.0}]

    Example:
        Pass non-differentiable values to ``jax.grad``

        >>> import sepes as sp
        >>> import jax
        >>> @jax.grad
        ... def square(tree):
        ...     tree = sp.tree_unmask(tree)
        ...     return tree[0] ** 2
        >>> tree = (1., 2)  # contains a non-differentiable node
        >>> square(sp.tree_mask(tree))
        (Array(2., dtype=float32, weak_type=True), #2)

    Example:
        Define a custom masking wrapper for a specific type.

        >>> import sepes as sp
        >>> import jax
        >>> import dataclasses as dc
        >>> @dc.dataclass
        ... class MyInt:
        ...     value: int
        >>> @dc.dataclass
        ... class MaskedInt:
        ...     value: MyInt
        >>> # define a rule of how to mask an integer
        >>> @sp.tree_mask.def_type(MyInt)
        ... def mask_int(value):
        ...     return MaskedInt(value)
        >>> # define a rule how to unmask the wrapper
        >>> @sp.tree_unmask.def_type(MaskedInt)
        ... def unmask_int(value):
        ...     return value.value
        >>> tree = [MyInt(1), MyInt(2), {"a": MyInt(3)}]
        >>> masked_tree = sp.tree_mask(tree, cond=lambda _: True)
        >>> masked_tree
        [MaskedInt(value=MyInt(value=1)), MaskedInt(value=MyInt(value=2)), {'a': MaskedInt(value=MyInt(value=3))}]
        >>> sp.tree_unmask(masked_tree)
        [MyInt(value=1), MyInt(value=2), {'a': MyInt(value=3)}]
    """
    # ``is_masked`` is passed as the leaf predicate so that mask wrappers are
    # handed to the unmasking rule as whole values
    unwrap = tree_unmask.dispatcher
    return _tree_mask_map(tree, cond=cond, func=unwrap, is_leaf=is_masked)
# values without a registered unmasking rule are returned unchanged
tree_unmask.dispatcher = ft.singledispatch(lambda x: x)
# expose the dispatcher's ``register`` so per-type unmasking rules can be added
tree_unmask.def_type = tree_unmask.dispatcher.register


for ndarray in arraylib.ndarrays:

    @tree_mask.def_type(ndarray)
    def mask_array(value: T) -> _MaskedArray[T]:
        # arrays get the dedicated wrapper, which implements hash and equality
        # by converting the array to bytes and hashing the bytes
        return _MaskedArray(value)


@tree_mask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> _MaskBase[T]:
    # masking is idempotent: mask(mask(x)) == mask(x). An already masked value
    # is handed back as-is, which avoids nested wrappers that would need
    # recursive unwrapping; masking a masked value is meaningless anyway.
    return value
def is_masked(value: Any) -> bool:
    """Return ``True`` if ``value`` is an instance of a mask wrapper type.

    The wrapper types are the types that have an unmasking rule registered on
    ``tree_unmask.dispatcher``, excluding the ``object`` fallback entry.
    """
    registry = tree_unmask.dispatcher.registry
    wrapper_types = tuple(kind for kind in registry if kind is not object)
    return isinstance(value, wrapper_types)
@tree_unmask.def_type(_MaskBase)
def _(value: _MaskBase[T]) -> T:
    # unmasking a wrapper hands back the node stored inside it
    return value.__wrapped__


if is_package_avaiable("jax"):
    import jax

    # Leave ``jax.core.Tracer`` instances untouched. Masking a tracer inside a
    # jax transformation would hide the tracer from jax and cause a leaked
    # tracer error.
    @tree_mask.def_type(jax.core.Tracer)
    def _(value: jax.core.Tracer) -> jax.core.Tracer:
        return value