# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import functools as ft
from typing import Sequence
import jax
import jax.random as jr
from serket import TreeClass
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils.convert import canonicalize, delayed_canonicalize_padding
from serket._src.utils.mapping import kernel_map
from serket._src.utils.typing import (
KernelSizeType,
MethodKind,
PaddingType,
StridesType,
)
from serket._src.utils.validate import validate_spatial_ndim
[docs]
def random_crop_nd(
    key: jax.Array,
    input: jax.Array,
    crop_size: Sequence[int],
) -> jax.Array:
    """Crop an input to the given size at a random start along each axis.

    Args:
        key: random key. It is split so that every axis draws its start
            offset from its own subkey.
        input: input array.
        crop_size: size of the crop along each axis. Accepts a sequence of
            int with one entry per axis of ``input``.

    Returns:
        The slice of ``input`` of shape ``crop_size`` whose start offset along
        axis ``i`` is sampled from ``0`` to ``input.shape[i] - crop_size[i]``
        (both inclusive).
    """
    crop_size = tuple(crop_size)
    # one subkey per axis: reusing ``key`` for every axis would tie all the
    # offsets to the same random draw.
    subkeys = jr.split(key, len(crop_size))
    # ``maxval`` is exclusive, hence the ``+ 1`` that makes the last valid
    # offset (``shape - size``) reachable.
    start = tuple(
        jr.randint(subkey, shape=(), minval=0, maxval=input.shape[i] - size + 1)
        for i, (subkey, size) in enumerate(zip(subkeys, crop_size))
    )
    return jax.lax.dynamic_slice(input, start, crop_size)
[docs]
def center_crop_nd(input: jax.Array, sizes: tuple[int, ...]) -> jax.Array:
    """Crop an input to the given size around its center.

    Args:
        input: input array.
        sizes: size of the crop along each axis. Accepts a tuple of int.

    Returns:
        The slice of ``input`` of shape ``sizes`` whose start along each axis
        is ``shape // 2 - size // 2``, clipped below at zero.
    """
    starts = []
    for extent, size in zip(input.shape, sizes):
        offset = extent // 2 - size // 2
        # never start before the first element of the axis.
        starts.append(offset if offset > 0 else 0)
    return jax.lax.dynamic_slice(input, tuple(starts), sizes)
[docs]
def upsample_nd(
    input: jax.Array,
    scale: int | Sequence[int],
    method: MethodKind = "nearest",
) -> jax.Array:
    """Upsample an input by an integer scale using a given interpolation method.

    Args:
        input: input array.
        scale: the scale of the output. Accepts a single int, applied to every
            axis of ``input``, or a sequence of int denoting the scale
            multiplier along each axis.
        method: interpolation method forwarded to ``jax.image.resize``.
            Defaults to ``nearest``. Choices are:

            - ``nearest``: Nearest neighbor interpolation.
            - ``linear``, ``bilinear``, ``trilinear``, ``triangle``: Linear
              interpolation.
            - ``cubic``, ``bicubic``, ``tricubic``: Cubic interpolation, using
              the Keys cubic kernel.
            - ``lanczos3``: Lanczos resampling, using a kernel of radius 3.
            - ``lanczos5``: Lanczos resampling, using a kernel of radius 5.

    Returns:
        The resized array; axis ``i`` has length ``scale[i] * input.shape[i]``.
    """
    if isinstance(scale, int):
        # the annotation allows a bare int: broadcast it to every axis.
        scale = (scale,) * input.ndim
    resized_shape = tuple(s * input.shape[i] for i, s in enumerate(scale))
    return jax.image.resize(input, resized_shape, method)
class UpsampleND(TreeClass):
    """Base layer that upsamples an input by a per-axis integer scale.

    Subclasses fix ``spatial_ndim``. The input is mapped over its leading axis
    and every slice is resized with ``upsample_nd``.

    Args:
        scale: the scale of the output, canonicalized against ``spatial_ndim``.
        method: interpolation method passed on to ``upsample_nd``.
    """

    def __init__(
        self,
        scale: int | tuple[int, ...] = 1,
        method: MethodKind = "nearest",
    ):
        # the output shape is expressed as a multiplier of the input shape
        # (a scale) rather than as an absolute size.
        self.scale = canonicalize(scale, self.spatial_ndim, name="scale")
        self.method = method

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        # scale and method are shared by every slice of the leading axis, so
        # they are bound up front and only ``input`` is mapped.
        upsample = ft.partial(upsample_nd, scale=self.scale, method=self.method)
        return jax.vmap(upsample)(input)

    @property
    @abc.abstractmethod
    def spatial_ndim(self) -> int:
        """Number of spatial dimensions of the image."""
        ...
[docs]
class Upsample1D(UpsampleND):
    """Upsample a 1D input by a given scale using a given interpolation method.

    Args:
        scale: the scale of the output, i.e. the multiplier applied to the
            length of the spatial axis.
        method: Interpolation method. Defaults to ``nearest``. Choices are:

            - ``nearest``: Nearest neighbor interpolation.
            - ``linear``, ``bilinear``, ``trilinear``, ``triangle``: Linear
              interpolation.
            - ``cubic``, ``bicubic``, ``tricubic``: Cubic interpolation, using
              the Keys cubic kernel.
            - ``lanczos3``: Lanczos resampling, using a kernel of radius 3.
            - ``lanczos5``: Lanczos resampling, using a kernel of radius 5.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.nn.Upsample1D(scale=2)
        >>> input = jnp.arange(1, 6).reshape(1, 5)
        >>> print(layer(input))
        [[1 1 2 2 3 3 4 4 5 5]]
    """

    spatial_ndim: int = 1
[docs]
class Upsample2D(UpsampleND):
    """Upsample a 2D input by a given scale using a given interpolation method.

    Args:
        scale: the scale of the output. Accepts a single int or a tuple of two
            int denoting the scale multiplier along each axis.
        method: Interpolation method. Defaults to ``nearest``. Choices are:

            - ``nearest``: Nearest neighbor interpolation.
            - ``linear``, ``bilinear``, ``trilinear``, ``triangle``: Linear
              interpolation.
            - ``cubic``, ``bicubic``, ``tricubic``: Cubic interpolation, using
              the Keys cubic kernel.
            - ``lanczos3``: Lanczos resampling, using a kernel of radius 3.
            - ``lanczos5``: Lanczos resampling, using a kernel of radius 5.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.nn.Upsample2D(scale=(1, 2))
        >>> input = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(layer(input))
        [[[ 1  1  2  2  3  3  4  4  5  5]
          [ 6  6  7  7  8  8  9  9 10 10]
          [11 11 12 12 13 13 14 14 15 15]
          [16 16 17 17 18 18 19 19 20 20]
          [21 21 22 22 23 23 24 24 25 25]]]
    """

    spatial_ndim: int = 2
[docs]
class Upsample3D(UpsampleND):
    """Upsample a 3D input by a given scale using a given interpolation method.

    Args:
        scale: the scale of the output. Accepts a single int or a tuple of three
            int denoting the scale multiplier along each axis.
        method: Interpolation method. Defaults to ``nearest``. Choices are:

            - ``nearest``: Nearest neighbor interpolation.
            - ``linear``, ``bilinear``, ``trilinear``, ``triangle``: Linear
              interpolation.
            - ``cubic``, ``bicubic``, ``tricubic``: Cubic interpolation, using
              the Keys cubic kernel.
            - ``lanczos3``: Lanczos resampling, using a kernel of radius 3.
            - ``lanczos5``: Lanczos resampling, using a kernel of radius 5.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.nn.Upsample3D(scale=(1, 2, 1))
        >>> input = jnp.arange(1, 9).reshape(1, 2, 2, 2)
        >>> print(layer(input))
        [[[[1 2]
           [1 2]
           [3 4]
           [3 4]]
        <BLANKLINE>
          [[5 6]
           [5 6]
           [7 8]
           [7 8]]]]
    """

    spatial_ndim: int = 3
class RandomCropND(TreeClass):
    """Base layer that crops an input to a fixed size at a random start.

    Subclasses fix ``spatial_ndim``. The leading axis of the input is kept
    whole; the remaining axes are cropped to ``size`` with ``random_crop_nd``.

    Args:
        size: size of the crop, canonicalized against ``spatial_ndim``.
    """

    def __init__(self, size: int | tuple[int, ...]):
        self.size = canonicalize(size, self.spatial_ndim, name="size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array, *, key: jax.Array) -> jax.Array:
        # requesting the full extent of the leading axis leaves it uncropped.
        leading = input.shape[0]
        return random_crop_nd(key, input, crop_size=(leading, *self.size))

    @property
    @abc.abstractmethod
    def spatial_ndim(self) -> int:
        """Number of spatial dimensions of the image."""
        ...
[docs]
class RandomCrop1D(RandomCropND):
    """Crop a 1D input to the given size at a random start.

    Args:
        size: size of the slice. Accepted values are either a single int or a
            tuple of one int denoting the size.
    """

    spatial_ndim: int = 1
[docs]
class RandomCrop2D(RandomCropND):
    """Crop a 2D input to the given size at a random start.

    Args:
        size: size of the slice in each axis. Accepted values are either a
            single int or a tuple of two ints denoting the size along each axis.
    """

    spatial_ndim: int = 2
[docs]
class RandomCrop3D(RandomCropND):
    """Crop a 3D input to the given size at a random start.

    Args:
        size: size of the slice in each axis. Accepted values are either a
            single int or a tuple of three ints denoting the size along each
            axis.
    """

    spatial_ndim: int = 3
class CenterCropND(TreeClass):
    """Base layer that crops an input to a fixed size around its center.

    Subclasses fix ``spatial_ndim``. The input is mapped over its leading axis
    and every slice is cropped with ``center_crop_nd``.

    Args:
        size: size of the crop, canonicalized against ``spatial_ndim``.
    """

    def __init__(self, size: int | tuple[int, ...]):
        self.size = canonicalize(size, self.spatial_ndim, name="size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, input: jax.Array) -> jax.Array:
        def crop(array: jax.Array) -> jax.Array:
            # the same crop size is used for every slice of the leading axis.
            return center_crop_nd(array, sizes=self.size)

        return jax.vmap(crop)(input)

    @property
    @abc.abstractmethod
    def spatial_ndim(self) -> int:
        """Number of spatial dimensions of the image."""
        ...
[docs]
class CenterCrop1D(CenterCropND):
    """Crop a 1D input to the given size at the center.

    Args:
        size: The size of the output. Accepts a single int.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.nn.CenterCrop1D(4)
        >>> input = jnp.arange(1, 13).reshape(1, 12)
        >>> print(input)
        [[ 1  2  3  4  5  6  7  8  9 10 11 12]]
        >>> print(layer(input))
        [[5 6 7 8]]
    """

    spatial_ndim: int = 1
[docs]
class CenterCrop2D(CenterCropND):
    """Crop the center of a channel-first image.

    .. image:: ../_static/centercrop2d.png

    Args:
        size: The size of the output image. Accepts a single int or a tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.nn.CenterCrop2D(4)
        >>> input = jnp.arange(1, 145).reshape(1, 12, 12)
        >>> print(input)
        [[[  1   2   3   4   5   6   7   8   9  10  11  12]
          [ 13  14  15  16  17  18  19  20  21  22  23  24]
          [ 25  26  27  28  29  30  31  32  33  34  35  36]
          [ 37  38  39  40  41  42  43  44  45  46  47  48]
          [ 49  50  51  52  53  54  55  56  57  58  59  60]
          [ 61  62  63  64  65  66  67  68  69  70  71  72]
          [ 73  74  75  76  77  78  79  80  81  82  83  84]
          [ 85  86  87  88  89  90  91  92  93  94  95  96]
          [ 97  98  99 100 101 102 103 104 105 106 107 108]
          [109 110 111 112 113 114 115 116 117 118 119 120]
          [121 122 123 124 125 126 127 128 129 130 131 132]
          [133 134 135 136 137 138 139 140 141 142 143 144]]]
        >>> print(layer(input))
        [[[53 54 55 56]
          [65 66 67 68]
          [77 78 79 80]
          [89 90 91 92]]]
    """

    spatial_ndim: int = 2
[docs]
class CenterCrop3D(CenterCropND):
    """Crop a 3D input to the given size at the center.

    Args:
        size: The size of the output. Accepts a single int or a tuple of three
            ints.
    """

    spatial_ndim: int = 3
# Rule registered on ``tree_eval`` for ``RandomCropND`` layers: the layer it
# receives is ignored and an ``Identity`` layer is returned in its place.
# NOTE(review): presumably this disables random cropping when a model is
# switched to evaluation mode -- confirm against ``tree_eval``.
@tree_eval.def_eval(RandomCropND)
def _(_) -> Identity:
    return Identity()