# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools as ft
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.ndimage import map_coordinates
from serket import TreeClass, autoinit, field
from serket._src.custom_transform import tree_eval
from serket._src.nn.linear import Identity
from serket._src.utils.typing import CHWArray, HWArray
from serket._src.utils.validate import IsInstance, Range, validate_spatial_ndim
def affine_2d(array: HWArray, matrix: HWArray) -> HWArray:
    """Resample a 2D array through a 2x2 linear map about its center.

    Each output pixel at ``(row, col)`` is read from the input at
    ``matrix @ ((row, col) - center) + center`` with linear interpolation
    (``order=1``); samples that land outside the input are filled with zeros
    (the ``map_coordinates`` default boundary handling).

    Args:
        array: 2D array of shape ``(height, width)``.
        matrix: 2x2 matrix mapping output offsets (from the center) to input
            offsets.

    Returns:
        Array with the same shape as ``array``.
    """
    assert array.ndim == 2
    h, w = array.shape
    # Geometric center of the pixel grid. ``(n - 1) / 2`` (rather than
    # ``n // 2``) keeps even-sized images centered, so e.g. a 90 degree rotation
    # is an exact pixel permutation instead of being shifted by one pixel.
    # For odd sizes both formulas agree.
    center = (jnp.array((h, w)) - 1).reshape(2, 1) / 2
    coords = jnp.indices((h, w)).reshape(2, -1) - center
    coords = matrix @ coords + center
    return map_coordinates(array, coords, order=1).reshape((h, w))
[docs]
def horizontal_flip_2d(image: HWArray) -> HWArray:
    """Mirror a 2D image left-to-right by reversing its column order."""
    assert image.ndim == 2
    # reversing the width (last) axis
    return image[:, ::-1]
[docs]
def random_horizontal_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray:
    """Mirror a 2D image left-to-right with probability ``rate``.

    A single Bernoulli draw from ``key`` decides whether the whole image is
    flipped or returned unchanged.
    """
    should_flip = jr.bernoulli(key, rate)
    flipped = horizontal_flip_2d(image)
    return jnp.where(should_flip, flipped, image)
[docs]
def vertical_flip_2d(image: HWArray) -> HWArray:
    """Mirror a 2D image top-to-bottom by reversing its row order."""
    assert image.ndim == 2
    # reversing the height (first) axis
    return image[::-1, :]
[docs]
def random_vertical_flip_2d(key: jax.Array, image: HWArray, rate: float) -> HWArray:
    """Mirror a 2D image top-to-bottom with probability ``rate``.

    A single Bernoulli draw from ``key`` decides whether the whole image is
    flipped or returned unchanged.
    """
    should_flip = jr.bernoulli(key, rate)
    flipped = vertical_flip_2d(image)
    return jnp.where(should_flip, flipped, image)
[docs]
def horizontal_shear_2d(image: HWArray, angle: float) -> HWArray:
    """Shear rows by an angle in degrees.

    Each row is displaced horizontally in proportion to its distance from the
    image center; the shear matrix is applied with ``affine_2d``.
    """
    slope = jnp.tan(jnp.deg2rad(angle))
    # row coordinate is kept; column coordinate picks up ``slope * row``
    shear_matrix = jnp.array([[1, 0], [slope, 1]])
    return affine_2d(image, shear_matrix)
[docs]
def random_horizontal_shear_2d(
    key: jax.Array,
    image: HWArray,
    range: tuple[float, float],
) -> HWArray:
    """Shear rows by a random angle in degrees.

    Args:
        key: PRNG key used to draw the shear angle.
        image: 2D image to shear.
        range: ``(minval, maxval)`` bounds, in degrees, of the uniform
            distribution the angle is drawn from.

    Returns:
        The sheared image, produced by ``horizontal_shear_2d``.
    """
    minval, maxval = range
    angle = jr.uniform(key=key, shape=(), minval=minval, maxval=maxval)
    return horizontal_shear_2d(image, angle)
[docs]
def vertical_shear_2d(
    image: HWArray,
    angle: float,
) -> HWArray:
    """Shear columns by an angle in degrees.

    Each column is displaced vertically in proportion to its distance from the
    image center; the shear matrix is applied with ``affine_2d``.
    """
    slope = jnp.tan(jnp.deg2rad(angle))
    # column coordinate is kept; row coordinate picks up ``slope * col``
    shear_matrix = jnp.array([[1, slope], [0, 1]])
    return affine_2d(image, shear_matrix)
[docs]
def random_vertical_shear_2d(
    key: jax.Array,
    image: HWArray,
    range: tuple[float, float],
) -> HWArray:
    """Shear columns by a random angle in degrees.

    The angle is drawn uniformly between the two bounds in ``range`` using
    ``key`` and then applied with ``vertical_shear_2d``.
    """
    lo, hi = range
    sampled_angle = jr.uniform(key=key, shape=(), minval=lo, maxval=hi)
    return vertical_shear_2d(image, sampled_angle)
[docs]
def rotate_2d(image: HWArray, angle: float) -> HWArray:
    """Rotate a 2D image counter-clockwise by ``angle`` degrees."""
    # ``affine_2d`` maps output pixel positions back to input positions, so the
    # matrix handed to it is built from the negated angle.
    theta = jnp.deg2rad(-angle)
    cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
    rotation = jnp.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return affine_2d(image, rotation)
[docs]
def random_rotate_2d(
    key: jax.Array,
    image: HWArray,
    range: tuple[float, float],
) -> HWArray:
    """Rotate a 2D image counter-clockwise by a random angle in degrees.

    The angle is drawn uniformly between the two bounds in ``range`` using
    ``key`` and then applied with ``rotate_2d``.
    """
    lo, hi = range
    sampled_angle = jr.uniform(key=key, shape=(), minval=lo, maxval=hi)
    return rotate_2d(image, sampled_angle)
[docs]
def horizontal_translate_2d(image: HWArray, shift: int) -> HWArray:
    """Translate an image horizontally by a pixel value.

    A positive ``shift`` moves the content right and a negative one moves it
    left; vacated columns are zero-filled, and ``|shift| >= width`` yields an
    all-zero image.

    ``shift`` may be a Python int or an integer scalar array, including a traced
    one (e.g. drawn with ``jax.random.randint`` inside ``jax.jit``), because no
    Python-level branching on its value is performed.

    Args:
        image: 2D image of shape ``(height, width)``.
        shift: number of pixels to move the image along the width axis.

    Returns:
        The translated image, with the same shape and dtype as ``image``.
    """
    assert image.ndim == 2
    _, w = image.shape
    cols = jnp.arange(w)
    # ``jnp.roll`` accepts traced shifts; the columns that wrapped around are
    # then masked to zero so the result is a translation, not a rotation.
    rolled = jnp.roll(image, shift, axis=1)
    keep = (cols >= shift) & (cols < w + shift)
    return jnp.where(keep[None, :], rolled, jnp.zeros_like(rolled))
def vertical_translate_2d(image: HWArray, shift: int) -> HWArray:
    """Translate an image vertically by a pixel value.

    A positive ``shift`` moves the content down and a negative one moves it
    up; vacated rows are zero-filled, and ``|shift| >= height`` yields an
    all-zero image.

    ``shift`` may be a Python int or an integer scalar array, including a traced
    one (e.g. drawn with ``jax.random.randint`` inside ``jax.jit``), because no
    Python-level branching on its value is performed.

    Args:
        image: 2D image of shape ``(height, width)``.
        shift: number of pixels to move the image along the height axis.

    Returns:
        The translated image, with the same shape and dtype as ``image``.
    """
    assert image.ndim == 2
    h, _ = image.shape
    rows = jnp.arange(h)
    # ``jnp.roll`` accepts traced shifts; the rows that wrapped around are then
    # masked to zero so the result is a translation, not a rotation.
    rolled = jnp.roll(image, shift, axis=0)
    keep = (rows >= shift) & (rows < h + shift)
    return jnp.where(keep[:, None], rolled, jnp.zeros_like(rolled))
[docs]
def random_horizontal_translate_2d(key: jax.Array, image: HWArray) -> HWArray:
    """Translate a 2D image horizontally by a random number of pixels.

    The shift is an integer drawn with ``key`` from ``[-width, width)``.
    """
    _, width = image.shape
    offset = jr.randint(key, shape=(), minval=-width, maxval=width)
    return horizontal_translate_2d(image, offset)
[docs]
def random_vertical_translate_2d(key: jax.Array, image: HWArray) -> HWArray:
    """Translate a 2D image vertically by a random number of pixels.

    The shift is an integer drawn with ``key`` from ``[-height, height)``.
    """
    height, _ = image.shape
    offset = jr.randint(key, shape=(), minval=-height, maxval=height)
    return vertical_translate_2d(image, offset)
[docs]
class Rotate2D(TreeClass):
    """Rotate a 2D image by an angle in degrees in the CCW direction.

    .. image:: ../_static/rotate2d.png

    Args:
        angle: rotation angle in degrees, counter-clockwise.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.Rotate2D(90)(x))
        [[[ 5 10 15 20 25]
          [ 4  9 14 19 24]
          [ 3  8 13 18 23]
          [ 2  7 12 17 22]
          [ 1  6 11 16 21]]]
    """

    def __init__(self, angle: float):
        self.angle = angle

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # every channel is rotated by the same angle; no gradient flows to it
        rotate_channels = jax.vmap(rotate_2d, in_axes=(0, None))
        return rotate_channels(image, jax.lax.stop_gradient(self.angle))

    spatial_ndim: int = 2
[docs]
class RandomRotate2D(TreeClass):
    """Rotate a 2D image by a random angle in degrees in the CCW direction.

    Args:
        range: a tuple of min angle and max angle to randomly choose from.

    Note:
        - Use :func:`tree_eval` to replace this layer with :class:`Identity` during
          evaluation.

        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4)
        >>> layer = sk.image.RandomRotate2D((10, 30))
        >>> eval_layer = sk.tree_eval(layer)
        >>> print(eval_layer(x))
        [[[ 1  2  3  4]
          [ 5  6  7  8]
          [ 9 10 11 12]
          [13 14 15 16]]]

    Example:
        >>> import serket as sk
        >>> import jax
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.RandomRotate2D((10, 30))(x, key=jax.random.key(0))) #doctest: +SKIP
        [[[ 1  2  4  7  4]
          [ 4  6  9 11 11]
          [ 8 10 13 16 18]
          [10 15 17 20 22]
          [ 8 19 22 18 11]]]
    """

    def __init__(self, range: tuple[float, float] = (0.0, 360.0)):
        is_pair = isinstance(range, tuple) and len(range) == 2
        if not is_pair or not all(isinstance(bound, (int, float)) for bound in range):
            raise ValueError(f"`{range=}` must be a tuple of 2 floats/ints ")
        self.range = range

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
        # the key is not mapped over channels, so all channels share one angle
        bounds = jax.lax.stop_gradient(self.range)
        rotate_channels = jax.vmap(random_rotate_2d, in_axes=(None, 0, None))
        return rotate_channels(key, image, bounds)

    spatial_ndim: int = 2
[docs]
class HorizontalShear2D(TreeClass):
    """Shear an image horizontally.

    .. image:: ../_static/horizontalshear2d.png

    Args:
        angle: shear angle in degrees.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.HorizontalShear2D(45)(x))
        [[[ 0  0  1  2  3]
          [ 0  6  7  8  9]
          [11 12 13 14 15]
          [17 18 19 20  0]
          [23 24 25  0  0]]]
    """

    def __init__(self, angle: float):
        self.angle = angle

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # same shear angle for every channel; no gradient flows to it
        shear_channels = jax.vmap(horizontal_shear_2d, in_axes=(0, None))
        return shear_channels(image, jax.lax.stop_gradient(self.angle))

    spatial_ndim: int = 2
[docs]
class RandomHorizontalShear2D(TreeClass):
    """Shear an image horizontally with random angle choice.

    The same randomly drawn angle is applied to every channel.

    Args:
        range: a tuple of min angle and max angle (in degrees) to randomly
            choose from.

    Note:
        - Use :func:`tree_eval` to replace this layer with :class:`Identity` during
          evaluation.

        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4)
        >>> layer = sk.image.RandomHorizontalShear2D((45, 45))
        >>> eval_layer = sk.tree_eval(layer)
        >>> print(eval_layer(x))
        [[[ 1  2  3  4]
          [ 5  6  7  8]
          [ 9 10 11 12]
          [13 14 15 16]]]

    Example:
        >>> import serket as sk
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.RandomHorizontalShear2D((45, 45))(x, key=jr.key(0)))
        [[[ 0  0  1  2  3]
          [ 0  6  7  8  9]
          [11 12 13 14 15]
          [17 18 19 20  0]
          [23 24 25  0  0]]]
    """

    def __init__(self, range: tuple[float, float] = (0.0, 90.0)):
        if not (
            isinstance(range, tuple)
            and len(range) == 2
            and isinstance(range[0], (int, float))
            and isinstance(range[1], (int, float))
        ):
            raise ValueError(f"`{range=}` must be a tuple of 2 floats")
        self.range = range

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # a random angle cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        bounds = jax.lax.stop_gradient(self.range)
        # key is not mapped over channels, so every channel gets the same angle
        in_axes = (None, 0, None)
        return jax.vmap(random_horizontal_shear_2d, in_axes=in_axes)(key, image, bounds)

    spatial_ndim: int = 2
[docs]
class VerticalShear2D(TreeClass):
    """Shear an image vertically.

    .. image:: ../_static/verticalshear2d.png

    Args:
        angle: shear angle in degrees.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.VerticalShear2D(45)(x))
        [[[ 0  0  3  9 15]
          [ 0  2  8 14 20]
          [ 1  7 13 19 25]
          [ 6 12 18 24  0]
          [11 17 23  0  0]]]
    """

    def __init__(self, angle: float):
        self.angle = angle

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # same shear angle for every channel; no gradient flows to it
        angle = jax.lax.stop_gradient(self.angle)
        return jax.vmap(vertical_shear_2d, in_axes=(0, None))(image, angle)

    spatial_ndim: int = 2
[docs]
class RandomVerticalShear2D(TreeClass):
    """Shear an image vertically with random angle choice.

    The same randomly drawn angle is applied to every channel.

    Args:
        range: a tuple of min angle and max angle (in degrees) to randomly
            choose from.

    Note:
        - Use :func:`tree_eval` to replace this layer with :class:`Identity` during
          evaluation.

        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4)
        >>> layer = sk.image.RandomVerticalShear2D((45, 45))
        >>> eval_layer = sk.tree_eval(layer)
        >>> print(eval_layer(x))
        [[[ 1  2  3  4]
          [ 5  6  7  8]
          [ 9 10 11 12]
          [13 14 15 16]]]

    Example:
        >>> import serket as sk
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.RandomVerticalShear2D((45, 45))(x, key=jr.key(0)))
        [[[ 0  0  3  9 15]
          [ 0  2  8 14 20]
          [ 1  7 13 19 25]
          [ 6 12 18 24  0]
          [11 17 23  0  0]]]
    """

    def __init__(self, range: tuple[float, float] = (0.0, 90.0)):
        if not (
            isinstance(range, tuple)
            and len(range) == 2
            and isinstance(range[0], (int, float))
            and isinstance(range[1], (int, float))
        ):
            raise ValueError(f"`{range=}` must be a tuple of 2 floats")
        self.range = range

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # a random angle cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        bounds = jax.lax.stop_gradient(self.range)
        # key is not mapped over channels, so every channel gets the same angle
        in_axes = (None, 0, None)
        return jax.vmap(random_vertical_shear_2d, in_axes=in_axes)(key, image, bounds)

    spatial_ndim: int = 2
[docs]
@autoinit
class HorizontalTranslate2D(TreeClass):
    """Translate an image horizontally by a pixel value.

    .. image:: ../_static/horizontaltranslate2d.png

    Args:
        shift: The number of pixels to shift the image by.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.HorizontalTranslate2D(2)(x))
        [[[ 0  0  1  2  3]
          [ 0  0  6  7  8]
          [ 0  0 11 12 13]
          [ 0  0 16 17 18]
          [ 0  0 21 22 23]]]
    """

    shift: int = field(on_setattr=[IsInstance(int)])

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # every channel is shifted by the same number of pixels
        translate_channels = jax.vmap(horizontal_translate_2d, in_axes=(0, None))
        return translate_channels(image, self.shift)

    spatial_ndim: int = 2
[docs]
@autoinit
class VerticalTranslate2D(TreeClass):
    """Translate an image vertically by a pixel value.

    .. image:: ../_static/verticaltranslate2d.png

    Args:
        shift: The number of pixels to shift the image by.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.VerticalTranslate2D(2)(x))
        [[[ 0  0  0  0  0]
          [ 0  0  0  0  0]
          [ 1  2  3  4  5]
          [ 6  7  8  9 10]
          [11 12 13 14 15]]]
    """

    shift: int = field(on_setattr=[IsInstance(int)])

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # every channel is shifted by the same number of pixels
        translate_channels = jax.vmap(vertical_translate_2d, in_axes=(0, None))
        return translate_channels(image, self.shift)

    spatial_ndim: int = 2
[docs]
@autoinit
class RandomHorizontalTranslate2D(TreeClass):
    """Translate an image horizontally by a random pixel value.

    The same randomly drawn shift is applied to every channel.

    Note:
        - Use :func:`tree_eval` to replace this layer with :class:`Identity` during
          evaluation.

        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4)
        >>> layer = sk.image.RandomHorizontalTranslate2D()
        >>> eval_layer = sk.tree_eval(layer)
        >>> print(eval_layer(x))
        [[[ 1  2  3  4]
          [ 5  6  7  8]
          [ 9 10 11 12]
          [13 14 15 16]]]

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.RandomHorizontalTranslate2D()(x, key=jr.key(0))) #doctest: +SKIP
        [[[ 4  5  0  0  0]
          [ 9 10  0  0  0]
          [14 15  0  0  0]
          [19 20  0  0  0]
          [24 25  0  0  0]]]
    """

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # a random shift cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        # key is not mapped over channels, so every channel gets the same shift
        return jax.vmap(random_horizontal_translate_2d, in_axes=(None, 0))(key, image)

    spatial_ndim: int = 2
[docs]
class RandomVerticalTranslate2D(TreeClass):
    """Translate an image vertically by a random pixel value.

    The same randomly drawn shift is applied to every channel.

    Note:
        - Use :func:`tree_eval` to replace this layer with :class:`Identity` during
          evaluation.

        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4)
        >>> layer = sk.image.RandomVerticalTranslate2D()
        >>> eval_layer = sk.tree_eval(layer)
        >>> print(eval_layer(x))
        [[[ 1  2  3  4]
          [ 5  6  7  8]
          [ 9 10 11 12]
          [13 14 15 16]]]

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> import jax.random as jr
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> print(sk.image.RandomVerticalTranslate2D()(x, key=jr.key(0))) #doctest: +SKIP
        [[[16 17 18 19 20]
          [21 22 23 24 25]
          [ 0  0  0  0  0]
          [ 0  0  0  0  0]
          [ 0  0  0  0  0]]]
    """

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # a random shift cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        # key is not mapped over channels, so every channel gets the same shift
        return jax.vmap(random_vertical_translate_2d, in_axes=(None, 0))(key, image)

    spatial_ndim: int = 2
[docs]
class HorizontalFlip2D(TreeClass):
    """Flip every channel of an image left to right.

    .. image:: ../_static/horizontalflip2d.png

    Examples:
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> x = jnp.arange(1,10).reshape(1,3, 3)
        >>> print(x)
        [[[1 2 3]
          [4 5 6]
          [7 8 9]]]
        >>> print(sk.image.HorizontalFlip2D()(x))
        [[[3 2 1]
          [6 5 4]
          [9 8 7]]]

    Reference:
        - https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
    """

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        flip_channels = jax.vmap(horizontal_flip_2d)
        return flip_channels(image)

    spatial_ndim: int = 2
[docs]
@autoinit
class RandomHorizontalFlip2D(TreeClass):
    """Flip channels left to right with a probability of `rate`.

    A single random draw is shared by all channels, so either every channel is
    flipped or none is.

    .. image:: ../_static/horizontalflip2d.png

    Args:
        rate: The probability of flipping the image.

    Note:
        Use :func:`tree_eval` to replace this layer with :class:`Identity` during
        evaluation.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> key = jax.random.key(0)
        >>> print(sk.image.RandomHorizontalFlip2D(rate=1.0)(x, key=key)) #doctest: +SKIP
        [[[ 5  4  3  2  1]
          [10  9  8  7  6]
          [15 14 13 12 11]
          [20 19 18 17 16]
          [25 24 23 22 21]]]
    """

    rate: float = field(on_setattr=[IsInstance(float), Range(0.0, 1.0)])

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # the flip decision cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        in_axes = (None, 0, None)
        rate = jax.lax.stop_gradient(self.rate)
        return jax.vmap(random_horizontal_flip_2d, in_axes=in_axes)(key, image, rate)

    spatial_ndim: int = 2
[docs]
class VerticalFlip2D(TreeClass):
    """Flip every channel of an image upside down.

    .. image:: ../_static/verticalflip2d.png

    Examples:
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> x = jnp.arange(1,10).reshape(1,3, 3)
        >>> print(x)
        [[[1 2 3]
          [4 5 6]
          [7 8 9]]]
        >>> print(sk.image.VerticalFlip2D()(x))
        [[[7 8 9]
          [4 5 6]
          [1 2 3]]]

    Reference:
        - https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
    """

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        flip_channels = jax.vmap(vertical_flip_2d)
        return flip_channels(image)

    spatial_ndim: int = 2
[docs]
@autoinit
class RandomVerticalFlip2D(TreeClass):
    """Flip channels up to down with a probability of `rate`.

    A single random draw is shared by all channels, so either every channel is
    flipped or none is.

    .. image:: ../_static/verticalflip2d.png

    Args:
        rate: The probability of flipping the image.

    Note:
        Use :func:`tree_eval` to replace this layer with :class:`Identity` during
        evaluation.

    Example:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> import serket as sk
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> key = jax.random.key(0)
        >>> print(sk.image.RandomVerticalFlip2D(rate=1.0)(x, key=key)) #doctest: +SKIP
        [[[21 22 23 24 25]
          [16 17 18 19 20]
          [11 12 13 14 15]
          [ 6  7  8  9 10]
          [ 1  2  3  4  5]]]
    """

    rate: float = field(on_setattr=[IsInstance(float), Range(0.0, 1.0)])

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array | None = None) -> CHWArray:
        if key is None:
            # ``None`` is kept as the default for signature compatibility only;
            # the flip decision cannot be drawn without a key.
            raise TypeError(f"{type(self).__name__} requires a PRNG `key`.")
        in_axes = (None, 0, None)
        rate = jax.lax.stop_gradient(self.rate)
        return jax.vmap(random_vertical_flip_2d, in_axes=in_axes)(key, image, rate)

    spatial_ndim: int = 2
@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalFlip2D)
@tree_eval.def_eval(RandomVerticalFlip2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
@tree_eval.def_eval(RandomVerticalShear2D)
@tree_eval.def_eval(RandomHorizontalTranslate2D)
@tree_eval.def_eval(RandomVerticalTranslate2D)
def _(_layer):
    """Evaluation-mode replacement for the random geometric augmentations.

    Registered with :func:`tree_eval` for each random layer listed in the
    decorators above; the layer instance itself is ignored and an
    :class:`Identity` layer is returned in its place.
    """
    return Identity()