# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import functools as ft
from typing import Callable, Sequence
import jax
import jax.numpy as jnp
from typing_extensions import Annotated
from serket import TreeClass
from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding
from serket._src.utils.mapping import kernel_map
from serket._src.utils.typing import KernelSizeType, PaddingType, StridesType
from serket._src.utils.validate import validate_spatial_ndim
def pool_nd(
    reducer: Callable[[jax.Array], jax.Array],
    inital_value: float,
    input: Annotated[jax.Array, "I..."],
    kernel_size: Sequence[int],
    strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
):
    """Apply a kernel-wise reduction to every channel of ``input``.

    Args:
        reducer: reduction applied to each kernel view. Takes the view and
            returns a single value.
        inital_value: value forwarded to ``kernel_map`` as ``padding_mode``
            (presumably the fill value of the padded positions -- confirm
            against ``kernel_map``).
        input: channeled input of shape (channels, spatial_dims)
        kernel_size: size of the kernel. accepts a sequence of ints, one for
            each spatial dimension
        strides: strides of the kernel. accepts a sequence of ints, one for
            each spatial dimension
        padding: padding of the kernel. accepts a sequence of tuples of two
            ints, one tuple per spatial dimension, one int per side

    Returns:
        The output of the channel-wise mapped reduction.
    """
    # the leading axis is the channel axis; the rest are spatial
    _channels, *spatial_shape = input.shape

    def apply_reducer(view):
        return reducer(view)

    # build the per-channel windowed reduction ...
    per_channel = kernel_map(
        apply_reducer,
        shape=spatial_shape,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        padding_mode=inital_value,
    )
    # ... and map it over the channel axis
    return jax.vmap(per_channel)(input)
def _max_with_floor(x):
    # NOTE(review): the ``maximum`` with ``-inf`` looks like a no-op for
    # ordinary values; it is kept exactly as in the original definition.
    return jnp.maximum(jnp.max(x), -jnp.inf)


# Maximum over all elements of ``x`` with a hand-written JVP rule.
max_op = jax.custom_jvp(_max_with_floor)


@max_op.defjvp
def _(primals, tangents):
    """JVP rule: forward only the tangent located at the argmax of ``x``."""
    (x,) = primals
    (g,) = tangents
    # ``argmax`` without an axis indexes into the flattened array, hence
    # the tangent is flattened before the lookup
    winner = x.argmax()
    return max_op(x), g.ravel()[winner]
[docs]
def max_pool_nd(
    input: jax.Array,
    kernel_size: Sequence[int],
    strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
) -> jax.Array:
    """Max pooling over the spatial dimensions of a channeled input.

    Delegates to ``pool_nd`` with ``max_op`` as the reducer and ``-inf`` as
    the initial value.

    Args:
        input: channeled input of shape (channels, spatial_dims)
        kernel_size: size of the kernel. accepts a sequence of ints, one for
            each spatial dimension
        strides: strides of the kernel. accepts a sequence of ints, one for
            each spatial dimension
        padding: padding of the kernel. accepts a sequence of tuples of two
            ints, one tuple per spatial dimension, one int per side

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> kernel_size = (3, 3)
        >>> strides = (2, 2)
        >>> input = jnp.ones((2, 25, 25))
        >>> padding = ((1, 1), (1, 1))  # pad 1 on each side of the spatial dimensions
        >>> output = sk.nn.max_pool_nd(input, kernel_size, strides, padding)
        >>> print(output.shape)
        (2, 13, 13)
    """
    return pool_nd(
        reducer=max_op,
        inital_value=-jnp.inf,
        input=input,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
    )
[docs]
def avg_pool_nd(
    input: jax.Array,
    kernel_size: Sequence[int],
    strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
) -> jax.Array:
    """Average pooling over the spatial dimensions of a channeled input.

    Delegates to ``pool_nd`` with ``jnp.mean`` as the reducer and ``0`` as
    the initial value.

    Args:
        input: channeled input of shape (channels, spatial_dims)
        kernel_size: size of the kernel. accepts tuple of ints, one for each
            spatial dimension
        strides: strides of the kernel. accepts tuple of ints, one for each
            spatial dimension
        padding: padding of the kernel. accepts tuple of tuples of two ints,
            one tuple per spatial dimension, one int per side

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> kernel_size = (3, 3)
        >>> strides = (2, 2)
        >>> input = jnp.ones((2, 25, 25))
        >>> padding = ((1, 1), (1, 1))  # pad 1 on each side of the spatial dimensions
        >>> output = sk.nn.avg_pool_nd(input, kernel_size, strides, padding)
        >>> print(output.shape)
        (2, 13, 13)
    """
    return pool_nd(
        reducer=jnp.mean,
        inital_value=0,
        input=input,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
    )
[docs]
def lp_pool_nd(
    input: jax.Array,
    norm_type: float,
    kernel_size: Sequence[int],
    strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
) -> jax.Array:
    """Lp pooling over the spatial dimensions of a channeled input.

    Each kernel view ``v`` is reduced to ``sum(v**norm_type) ** (1 / norm_type)``
    and ``0`` is handed to ``pool_nd`` as the initial value.

    Args:
        input: channeled input of shape (channels, spatial_dims)
        norm_type: norm type as a float
        kernel_size: size of the kernel. accepts tuple of ints, one for each
            spatial dimension
        strides: strides of the kernel. accepts tuple of ints, one for each
            spatial dimension
        padding: padding of the kernel. accepts tuple of tuples of two ints,
            one tuple per spatial dimension, one int per side

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> kernel_size = (3, 3)
        >>> strides = (2, 2)
        >>> input = jnp.ones((2, 25, 25))
        >>> norm_type = 2
        >>> padding = ((1, 1), (1, 1))  # pad 1 on each side of the spatial dimensions
        >>> output = sk.nn.lp_pool_nd(input, norm_type, kernel_size, strides, padding)
        >>> print(output.shape)
        (2, 13, 13)
    """

    def lp_norm(window: jax.Array) -> jax.Array:
        # the elements are raised to ``norm_type`` as they are (no ``abs``)
        powered_sum = jnp.sum(window**norm_type)
        return powered_sum ** (1 / norm_type)

    return pool_nd(
        reducer=lp_norm,
        inital_value=0,
        input=input,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
    )
def adaptive_pool_nd(
    reducer: Callable[[jax.Array], jax.Array],
    input: jax.Array,
    out_dim: Sequence[int],
) -> jax.Array:
    """Pool every channel of ``input`` with strides/kernel derived from ``out_dim``.

    For each spatial dimension of size ``i`` and requested size ``o`` the
    stride is ``i // o`` and the kernel size is ``i - (o - 1) * stride``.
    No padding is applied.

    Args:
        reducer: reduction applied to each kernel view
        input: channeled input; the spatial shape is ``input.shape[1:]``
        out_dim: requested output size, one int per spatial dimension

    Returns:
        The output of the channel-wise mapped reduction.

    Note:
        NOTE(review): a requested size larger than the input size yields a
        zero stride here; confirm how ``kernel_map`` treats that case.
    """
    in_dim = input.shape[1:]

    stride_list, kernel_list = [], []
    for in_size, out_size in zip(in_dim, out_dim):
        step = in_size // out_size
        stride_list.append(step)
        # the kernel spans whatever the strided steps leave uncovered
        kernel_list.append(in_size - (out_size - 1) * step)

    def apply_reducer(view: jax.Array) -> jax.Array:
        return reducer(view)

    per_channel = kernel_map(
        apply_reducer,
        shape=in_dim,
        kernel_size=tuple(kernel_list),
        strides=tuple(stride_list),
        padding=tuple((0, 0) for _ in in_dim),
    )
    # map the windowed reduction over the channel axis
    return jax.vmap(per_channel)(input)
[docs]
def adaptive_avg_pool_nd(input: jax.Array, out_dim: Sequence[int]) -> jax.Array:
    """Adaptive average pooling: ``adaptive_pool_nd`` with ``jnp.mean``.

    Args:
        input: channeled input of shape (channels, spatial_dims)
        out_dim: output dimension. accepts a sequence of ints, one for each
            spatial dimension

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> input = jnp.ones((2, 25, 25))
        >>> out_dim = (13, 13)
        >>> output = sk.nn.adaptive_avg_pool_nd(input, out_dim)
        >>> print(output.shape)
        (2, 13, 13)
    """
    return adaptive_pool_nd(reducer=jnp.mean, input=input, out_dim=out_dim)
[docs]
def adaptive_max_pool_nd(input: jax.Array, out_dim: Sequence[int]) -> jax.Array:
    """Adaptive max pooling: ``adaptive_pool_nd`` with ``max_op``.

    Args:
        input: channeled input of shape (channels, spatial_dims)
        out_dim: output dimension. accepts a sequence of ints, one for each
            spatial dimension

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> input = jnp.ones((2, 25, 25))
        >>> out_dim = (13, 13)
        >>> output = sk.nn.adaptive_max_pool_nd(input, out_dim)
        >>> print(output.shape)
        (2, 13, 13)
    """
    return adaptive_pool_nd(reducer=max_op, input=input, out_dim=out_dim)
class MaxPoolND(TreeClass):
    """Base class of the max pooling layers.

    Concrete subclasses provide ``spatial_ndim``; ``kernel_size`` and
    ``strides`` are canonicalized against it at construction time.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Stored as given and resolved on every
            call. Defaults to ``"valid"``.
    """

    def __init__(
        self,
        kernel_size: KernelSizeType,
        strides: StridesType = 1,
        *,
        padding: PaddingType = "valid",
    ):
        ndim = self.spatial_ndim
        self.kernel_size = canonicalize(kernel_size, ndim, name="kernel_size")
        self.strides = canonicalize(strides, ndim, name="strides")
        self.padding = padding

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Max-pool ``input`` with the stored kernel size, strides and padding."""
        # padding is resolved here because it is given the input shape
        # NOTE(review): ``input.shape`` includes the leading channel axis --
        # confirm ``delayed_canonicalize_padding`` expects the full shape.
        resolved_padding = delayed_canonicalize_padding(
            in_dim=input.shape,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
        )
        return max_pool_nd(input, self.kernel_size, self.strides, resolved_padding)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class MaxPool1D(MaxPoolND):
    """1D max pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> layer = sk.nn.MaxPool1D(kernel_size=2, strides=2)
        >>> x = jnp.arange(1, 11).reshape(1, 10).astype(jnp.float32)
        >>> print(layer(x))
        [[ 2.  4.  6.  8. 10.]]
    """

    # overrides the abstract ``spatial_ndim`` property of ``MaxPoolND``
    spatial_ndim: int = 1
[docs]
class MaxPool2D(MaxPoolND):
    """2D max pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> layer = sk.nn.MaxPool2D(kernel_size=2, strides=2)
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4).astype(jnp.float32)
        >>> print(layer(x))
        [[[ 6.  8.]
          [14. 16.]]]
    """

    # overrides the abstract ``spatial_ndim`` property of ``MaxPoolND``
    spatial_ndim: int = 2
[docs]
class MaxPool3D(MaxPoolND):
    """3D max pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``MaxPoolND``
    spatial_ndim: int = 3
class AvgPoolND(TreeClass):
    """Base class of the average pooling layers.

    Concrete subclasses provide ``spatial_ndim``; ``kernel_size`` and
    ``strides`` are canonicalized against it at construction time.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Stored as given and resolved on every
            call. Defaults to ``"valid"``.
    """

    def __init__(
        self,
        kernel_size: KernelSizeType,
        strides: StridesType = 1,
        *,
        padding: PaddingType = "valid",
    ):
        ndim = self.spatial_ndim
        self.kernel_size = canonicalize(kernel_size, ndim, name="kernel_size")
        self.strides = canonicalize(strides, ndim, name="strides")
        self.padding = padding

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Average-pool ``input`` with the stored kernel size, strides and padding."""
        # padding is resolved here because it is given the input shape
        # NOTE(review): ``input.shape`` includes the leading channel axis --
        # confirm ``delayed_canonicalize_padding`` expects the full shape.
        resolved_padding = delayed_canonicalize_padding(
            in_dim=input.shape,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
        )
        return avg_pool_nd(input, self.kernel_size, self.strides, resolved_padding)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class AvgPool1D(AvgPoolND):
    """1D average pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``AvgPoolND``
    spatial_ndim: int = 1
[docs]
class AvgPool2D(AvgPoolND):
    """2D average pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``AvgPoolND``
    spatial_ndim: int = 2
[docs]
class AvgPool3D(AvgPoolND):
    """3D average pooling layer.

    Args:
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel (valid, same) or tuple of ints.
            Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``AvgPoolND``
    spatial_ndim: int = 3
class LPPoolND(TreeClass):
    """Base class of the Lp pooling layers.

    Concrete subclasses provide ``spatial_ndim``; ``kernel_size`` and
    ``strides`` are canonicalized against it at construction time.

    Args:
        norm_type: norm type, stored as given
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Stored as given and resolved on every
            call. Defaults to ``"valid"``.
    """

    def __init__(
        self,
        norm_type: float,
        kernel_size: KernelSizeType,
        strides: StridesType = 1,
        *,
        padding: PaddingType = "valid",
    ):
        self.norm_type = norm_type
        ndim = self.spatial_ndim
        self.kernel_size = canonicalize(kernel_size, ndim, name="kernel_size")
        self.strides = canonicalize(strides, ndim, name="strides")
        self.padding = padding

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Lp-pool ``input`` with the stored norm type, kernel size, strides and padding."""
        # padding is resolved here because it is given the input shape
        # NOTE(review): ``input.shape`` includes the leading channel axis --
        # confirm ``delayed_canonicalize_padding`` expects the full shape.
        resolved_padding = delayed_canonicalize_padding(
            in_dim=input.shape,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
        )
        return lp_pool_nd(
            input,
            self.norm_type,
            self.kernel_size,
            self.strides,
            resolved_padding,
        )

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class LPPool1D(LPPoolND):
    """1D Lp pooling layer.

    Args:
        norm_type: norm type
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``LPPoolND``
    spatial_ndim: int = 1
[docs]
class LPPool2D(LPPoolND):
    """2D Lp pooling layer.

    Args:
        norm_type: norm type
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``LPPoolND``
    spatial_ndim: int = 2
[docs]
class LPPool3D(LPPoolND):
    """3D Lp pooling layer.

    Args:
        norm_type: norm type
        kernel_size: size of the kernel
        strides: strides of the kernel. Defaults to 1.
        padding: padding of the kernel. Defaults to ``"valid"``.
    """

    # overrides the abstract ``spatial_ndim`` property of ``LPPoolND``
    spatial_ndim: int = 3
class GlobalAvgPoolND(TreeClass):
    """Base class of the global average pooling layers.

    Averages the input over axes ``1 .. spatial_ndim``; axis 0 is left
    untouched.

    Args:
        keepdims: forwarded to ``jnp.mean`` as ``keepdims``. Defaults to True.
    """

    def __init__(self, keepdims: bool = True):
        self.keepdims = keepdims

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Return the mean of ``input`` over its spatial axes."""
        # spatial axes start right after the leading axis
        spatial_axes = tuple(axis + 1 for axis in range(self.spatial_ndim))
        return jnp.mean(input, axis=spatial_axes, keepdims=self.keepdims)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class GlobalAvgPool1D(GlobalAvgPoolND):
    """1D global average pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalAvgPoolND``
    spatial_ndim: int = 1
[docs]
class GlobalAvgPool2D(GlobalAvgPoolND):
    """2D global average pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalAvgPoolND``
    spatial_ndim: int = 2
[docs]
class GlobalAvgPool3D(GlobalAvgPoolND):
    """3D global average pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalAvgPoolND``
    spatial_ndim: int = 3
class GlobalMaxPoolND(TreeClass):
    """Base class of the global max pooling layers.

    Takes the maximum of the input over axes ``1 .. spatial_ndim``; axis 0 is
    left untouched.

    Args:
        keepdims: forwarded to ``jnp.max`` as ``keepdims``. Defaults to True.
    """

    def __init__(self, keepdims: bool = True):
        self.keepdims = keepdims

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Return the maximum of ``input`` over its spatial axes."""
        # spatial axes start right after the leading axis
        spatial_axes = tuple(axis + 1 for axis in range(self.spatial_ndim))
        return jnp.max(input, axis=spatial_axes, keepdims=self.keepdims)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class GlobalMaxPool1D(GlobalMaxPoolND):
    """1D global max pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalMaxPoolND``
    spatial_ndim: int = 1
[docs]
class GlobalMaxPool2D(GlobalMaxPoolND):
    """2D global max pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalMaxPoolND``
    spatial_ndim: int = 2
[docs]
class GlobalMaxPool3D(GlobalMaxPoolND):
    """3D global max pooling layer.

    Args:
        keepdims: whether to keep the reduced spatial dimensions in the
            output or not. Defaults to True.
    """

    # overrides the abstract ``spatial_ndim`` property of ``GlobalMaxPoolND``
    spatial_ndim: int = 3
class AdaptiveAvgPoolND(TreeClass):
    """Base class of the adaptive average pooling layers.

    Args:
        output_size: size of the output. Canonicalized against
            ``spatial_ndim`` at construction time.
    """

    def __init__(self, output_size: tuple[int, ...]):
        ndim = self.spatial_ndim
        self.output_size = canonicalize(output_size, ndim, name="output_size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Adaptively average-pool ``input`` to the stored output size."""
        return adaptive_avg_pool_nd(input=input, out_dim=self.output_size)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class AdaptiveAvgPool1D(AdaptiveAvgPoolND):
    """1D adaptive average pooling layer.

    Args:
        output_size: size of the output along the spatial dimension
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveAvgPoolND``
    spatial_ndim: int = 1
[docs]
class AdaptiveAvgPool2D(AdaptiveAvgPoolND):
    """2D adaptive average pooling layer.

    Args:
        output_size: size of the output along the spatial dimensions
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveAvgPoolND``
    spatial_ndim: int = 2
[docs]
class AdaptiveAvgPool3D(AdaptiveAvgPoolND):
    """3D adaptive average pooling layer.

    Args:
        output_size: size of the output along the spatial dimensions
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveAvgPoolND``
    spatial_ndim: int = 3
class AdaptiveMaxPoolND(TreeClass):
    """Base class of the adaptive max pooling layers.

    Args:
        output_size: size of the output. Canonicalized against
            ``spatial_ndim`` at construction time.
    """

    def __init__(self, output_size: tuple[int, ...]):
        ndim = self.spatial_ndim
        self.output_size = canonicalize(output_size, ndim, name="output_size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        """Adaptively max-pool ``input`` to the stored output size."""
        return adaptive_max_pool_nd(input=input, out_dim=self.output_size)

    # abstract placeholder, overridden by the concrete 1D/2D/3D subclasses
    spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class AdaptiveMaxPool1D(AdaptiveMaxPoolND):
    """1D adaptive max pooling layer.

    Args:
        output_size: size of the output along the spatial dimension
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveMaxPoolND``
    spatial_ndim: int = 1
[docs]
class AdaptiveMaxPool2D(AdaptiveMaxPoolND):
    """2D adaptive max pooling layer.

    Args:
        output_size: size of the output along the spatial dimensions
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveMaxPoolND``
    spatial_ndim: int = 2
[docs]
class AdaptiveMaxPool3D(AdaptiveMaxPoolND):
    """3D adaptive max pooling layer.

    Args:
        output_size: size of the output along the spatial dimensions
    """

    # overrides the abstract ``spatial_ndim`` property of ``AdaptiveMaxPoolND``
    spatial_ndim: int = 3