# Source code for serket._src.containers
# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools as ft
import inspect
import operator
from typing import Any, Callable, Sequence

import jax
import jax.random as jr

from serket import TreeClass, tree_summary
from serket._src.utils.dispatch import single_dispatch
@single_dispatch(argnum=0)
def sequential(key: jax.Array, _1, _2):
    """Apply layers to an input, dispatching on the type of ``key``.

    Concrete behaviour lives in the handlers registered with
    ``sequential.def_type``. This base implementation is reached only for key
    types that have no registered handler.

    Raises:
        TypeError: always, reporting the unsupported type of ``key``.
    """
    message = f"Invalid {type(key)=}"
    raise TypeError(message)
@sequential.def_type(type(None))
def _(key: None, layers: Sequence[Callable[..., Any]], array: Any):
    """Apply ``layers`` in order when no random key is supplied.

    Each layer is called as ``layer(value)`` on the previous layer's output.
    With no layers, ``array`` is returned unchanged.
    """
    del key  # without a key there is nothing to split or forward
    result = array
    for layer in layers:
        result = layer(result)
    return result
def _accepts_key_kwarg(layer: Callable[..., Any]) -> bool | None:
    """Decide from its signature whether ``layer`` takes a ``key=`` keyword.

    Returns:
        ``True`` if the signature declares a ``key`` parameter that can be
        passed by keyword. ``False`` if it declares neither such a parameter
        nor ``**kwargs``. ``None`` when this cannot be decided statically:
        either no signature is available, or only ``**kwargs`` could absorb
        ``key``.
    """
    try:
        signature = inspect.signature(layer)
    except (TypeError, ValueError):
        # Some builtins / C-extension callables expose no signature.
        return None
    params = signature.parameters
    key_param = params.get("key")
    if key_param is not None and key_param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        # ``**kwargs`` may or may not forward ``key`` somewhere that accepts it.
        return None
    return False


@sequential.def_type(jax.Array)
def _(key: jax.Array, layers: Sequence[Callable[..., Any]], array: Any):
    """Applies a sequence of layers to an array, threading a random key.

    ``key`` is split into one subkey per layer. Whether a layer receives its
    subkey as ``key=`` is decided from the layer's signature. Only when the
    signature is inconclusive is the layer tried with ``key=`` first and then
    re-called without it on ``TypeError``.

    Deciding up front means a ``TypeError`` raised *inside* a key-accepting
    layer propagates to the caller. It is no longer masked by a silent second
    call that runs without randomness.

    Args:
        key: a random number generator key supplied to the layers.
        layers: a sequence of callables.
        array: an array to apply the layers to.

    Returns:
        The output of the last layer, or ``array`` if ``layers`` is empty.
    """
    subkeys = jr.split(key, len(layers))
    for subkey, layer in zip(subkeys, layers):
        accepts_key = _accepts_key_kwarg(layer)
        if accepts_key is True:
            array = layer(array, key=subkey)
        elif accepts_key is False:
            array = layer(array)
        else:
            # Signature is unavailable or inconclusive: fall back to trial and
            # error, which is the only option left for such callables.
            try:
                array = layer(array, key=subkey)
            except TypeError:
                array = layer(array)
    return array
[docs]
class Sequential(TreeClass):
    """A sequential container for layers.

    Args:
        layers: the layers to apply, passed as separate positional arguments,
            e.g. ``Sequential(layer1, layer2)``. They are collected into a tuple
            so the container stays immutable. To build from an existing list
            or tuple, unpack it: ``Sequential(*layers)``.

    Example:
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> import serket as sk
        >>> layers = sk.Sequential(lambda x: x + 1, lambda x: x * 2)
        >>> print(layers(jnp.array([1, 2, 3]), key=jr.key(0)))
        [4 6 8]

    Note:
        A layer may be a plain function or an object with a ``__call__``
        method. It may optionally accept a ``key`` keyword argument for
        random number generation.
    """

    def __init__(self, *layers):
        # Var-args guarantee ``self.layers`` is a tuple, keeping the container
        # immutable regardless of what the caller passes.
        self.layers = layers

    def __call__(self, input: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
        """Apply the contained layers in order to ``input``.

        Args:
            input: the value passed to the first layer.
            key: optional random key. When ``None``, each layer is called as
                ``layer(x)``. Otherwise the key is split into one subkey per
                layer, and a subkey is passed as ``key=`` to layers that
                accept it.

        Returns:
            The output of the last layer, or ``input`` if there are no layers.
        """
        return sequential(key, self.layers, input)

    @single_dispatch(argnum=1)
    def __getitem__(self, key):
        """Index the layers: ``int`` gives a layer, ``slice`` gives a new container.

        Objects that implement ``__index__`` (e.g. NumPy integer scalars) are
        treated like ``int``, mirroring plain tuple indexing. This method is
        the fallback for types without a registered handler.

        Raises:
            TypeError: if ``key`` is neither a slice nor integer-like.
            IndexError: if an integer index is out of range.
        """
        try:
            index = operator.index(key)
        except TypeError as err:
            raise TypeError(f"Invalid index type: {type(key)}") from err
        return self.layers[index]

    @__getitem__.def_type(slice)
    def _(self, key: slice):
        # Slicing returns a new container (not a bare tuple), so the result
        # is still callable like the original.
        return type(self)(*self.layers[key])

    @__getitem__.def_type(int)
    def _(self, key: int):
        return self.layers[key]

    def __len__(self):
        """Return the number of contained layers."""
        return len(self.layers)

    def __iter__(self):
        """Iterate over the contained layers in application order."""
        return iter(self.layers)

    def __reversed__(self):
        """Iterate over the contained layers in reverse application order."""
        return reversed(self.layers)
@tree_summary.def_type(Sequential)
def _(node):
    """Summarize a container as ``ClassName[LayerType1,LayerType2,...]``.

    Uses the node's own class name, followed by the class name of each layer
    yielded by iterating the node, in order.
    """
    layer_names = ",".join(type(layer).__name__ for layer in node)
    return f"{type(node).__name__}[{layer_names}]"