Source code for serket._src.custom_transform

# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define dispatchers for custom tree transformations."""

from __future__ import annotations

from inspect import getfullargspec
from typing import Any, TypeVar

import jax

import serket as sk
from serket._src.utils.dispatch import single_dispatch

T = TypeVar("T")


class NoState(sk.TreeClass):
    """Placeholder used in place of a state for layers that have none.

    The constructor accepts the layer plus arbitrary keyword arguments and
    discards all of them, so the signature is compatible with any call made
    by the state dispatcher while the instance itself stores nothing.

    Args:
        layer: The layer the placeholder stands in for (ignored).
        **_: Any keyword arguments (ignored).
    """

    def __init__(self, layer: Any, **_):
        # nothing is kept on the instance: drop every argument explicitly
        del layer
        del _


def tree_state(tree: T, **kwargs) -> T:
    """Build state for a tree of layers.

    Some layers require state to be initialized before training. For example,
    :class:`nn.BatchNorm` layers requires ``running_mean`` and ``running_var`` to
    be initialized before training. This function initializes the state for a
    tree of layers, based on the layer defined ``state`` rule using
    ``tree_state.def_state``.

    :func:`.tree_state` objective is to provide a simple and consistent way to
    initialize state for a tree of layers. Specifically, it provides a way to
    separate the state initialization logic from the layer definition. This
    allows for a more clear separation of concerns, and makes it easier to
    define new layers.

    Args:
        tree: A tree of layers.
        kwargs: Keyword arguments to pass to the state initialization rule
            of the tree layers. Keyword arguments that a given rule does not
            declare as keyword-only parameters are not forwarded to that rule.

    Returns:
        A tree of state leaves if it has state, otherwise ``NoState`` placeholder.

    Raises:
        TypeError: If a layer has a registered state rule but one or more of
            the rule's required keyword-only arguments were not supplied, or
            if the rule itself raises ``TypeError``.

    Note:
        To define a state initialization rule for a custom layer, use the decorator
        :func:`.tree_state.def_state` on a function that accepts the layer as the
        first argument, for any additional arguments, use keyword only arguments.

        >>> import jax
        >>> import serket as sk
        >>> class LayerWithState(sk.TreeClass):
        ...    pass
        >>> # state function accept the `layer` and input array
        >>> @sk.tree_state.def_state(LayerWithState)
        ... def _(leaf, *, input: jax.Array) -> jax.Array:
        ...    return jax.random.normal(jax.random.key(0), input.shape)
        >>> sk.tree_state(LayerWithState(), input=jax.numpy.ones((1, 1))).shape
        (1, 1)

    Example:
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> import jax.random as jr
        >>> tree = [1, 2, sk.nn.BatchNorm(5, key=jr.key(0))]
        >>> sk.tree_state(tree)
        [NoState(), NoState(), BatchNormState(
          running_mean=f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00]),
          running_var=f32[5](μ=1.00, σ=0.00, ∈[1.00,1.00])
        )]
    """
    # tree_state handles state initialization for different layers
    # like RNN cells, BatchNorm, KMeans, etc.
    # one challenge is that the state initialization rule for a layer
    # may depend only on the layer itself, or may depend on the layer
    # and the input. For example, the state initialization rule for
    # ConvRNN Cells depends on the layer and sample input, but the state
    # initialization rule for some RNN cells (e.g. LSTM) does not depend on the
    # input. This poses a challenge for the user to pass the correct input
    # to the state initialization rule. To cope with that, every rule is first
    # called with all kwargs, and on ``TypeError`` it is retried with only the
    # keyword-only arguments the rule declares.

    # every type with an explicit rule is treated as a leaf so the rule sees
    # the whole layer instead of its flattened children
    types = tuple(set(tree_state.dispatcher.registry) - {object})

    def is_leaf(node: Any) -> bool:
        return isinstance(node, types)

    def dispatch_func(leaf):
        try:
            return tree_state.dispatcher(leaf, **kwargs)
        except TypeError as e:
            # look for a rule registered for the leaf type or one of its
            # bases (``object`` -the last mro entry- is deliberately skipped)
            registry = tree_state.dispatcher.registry
            for mro in type(leaf).__mro__[:-1]:
                if mro in registry:
                    func = registry[mro]
                    break
            else:
                # no state rule is registered for this leaf, so the type
                # error was raised for other reasons: re-raise it untouched
                # to keep the original traceback
                raise

            # a state rule is registered: work out which of the kwargs given
            # to `tree_state` the rule accepts and which ones it still needs
            spec = getfullargspec(func)
            accepted = spec.kwonlyargs
            optional = spec.kwonlydefaults or {}
            missing = [
                key for key in accepted if key not in kwargs and key not in optional
            ]

            if not missing:
                # all required kwargs are available: retry with only the
                # kwargs the rule declares, leaving defaults untouched
                selected = {key: kwargs[key] for key in accepted if key in kwargs}
                return func(leaf, **selected)

            # the state rule is registered but some of the kwargs it requires
            # were not passed to `tree_state`
            raise type(e)(
                f"{type(leaf)=} has a registered state rule {sk.tree_str(func)}."
                f"\nHowever, the kwargs = {','.join(missing)} "
                f"are not passed to the state rule.\n{e}"
            ) from e

    return jax.tree_util.tree_map(dispatch_func, tree, is_leaf=is_leaf)
# dispatcher used by `tree_state`; `NoState` is the function wrapped by
# `single_dispatch`, dispatching on the first positional argument (argnum=0)
tree_state.dispatcher = single_dispatch(argnum=0)(NoState)
# public decorator for registering a state rule for a layer type
tree_state.def_state = tree_state.dispatcher.def_type
def tree_eval(tree):
    """Modify tree layers to disable any training related behavior.

    For example, :class:`nn.Dropout` layer is replaced by an :class:`nn.Identity` layer
    and :class:`nn.BatchNorm` layer is replaced by :class:`.EvalBatchNorm` layer when
    evaluating the tree.

    :func:`.tree_eval` objective is to provide a simple and consistent way to
    disable any training related behavior for a tree of layers. Specifically,
    it provides a way to separate the evaluation logic from the layer definition.
    This allows for a more clear separation of concerns, and makes it easier
    to define new layers that have a single behavior.

    Args:
        tree: A tree of layers.

    Returns:
        A tree of layers with evaluation behavior of same structure as ``tree``.

    Example:
        >>> # dropout is replaced by an identity layer in evaluation mode
        >>> # by registering `tree_eval.def_eval(sk.nn.Dropout, sk.nn.Identity)`
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> layer = sk.nn.Dropout(0.5)
        >>> sk.tree_eval(layer)
        Identity()

    Note:
        To define evaluation rule for a custom layer, use the decorator
        :func:`.tree_eval.def_eval` on a function that accepts the layer. The
        function should return the evaluation layer.

        >>> import serket as sk
        >>> import jax
        >>> class AddOne(sk.TreeClass):
        ...     def __call__(self, input: jax.Array) -> jax.Array:
        ...         return input + 1
        >>> input = jax.numpy.ones([3, 3])
        >>> add_one = AddOne()
        >>> print(add_one(input))  # add one to each element
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
        <BLANKLINE>
        >>> class AddOneEval(sk.TreeClass):
        ...     def __call__(self, input: jax.Array) -> jax.Array:
        ...         return input  # no-op
        <BLANKLINE>
        >>> # register `AddOne` to be replaced by `AddOneEval` in evaluation mode
        >>> @sk.tree_eval.def_eval(AddOne)
        ... def _(_: AddOne) -> AddOneEval:
        ...     return AddOneEval()
        >>> print(sk.tree_eval(add_one)(input))
        [[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]]
    """
    # every type with an explicit evaluation rule is treated as a leaf so the
    # rule receives the whole layer; ``object`` is excluded from that set
    rule_types = tuple(set(tree_eval.dispatcher.registry) - {object})
    return jax.tree_util.tree_map(
        tree_eval.dispatcher,
        tree,
        is_leaf=lambda node: isinstance(node, rule_types),
    )
# dispatcher used by `tree_eval`; the wrapped function is the identity, so
# it returns its argument unchanged; dispatch is on the first argument
tree_eval.dispatcher = single_dispatch(argnum=0)(lambda x: x)
# public decorator for registering an evaluation rule for a layer type
tree_eval.def_eval = tree_eval.dispatcher.def_type