# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
import functools as ft
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.ndimage import map_coordinates
from serket import TreeClass
from serket._src.image.geometric import rotate_2d
from serket._src.nn.convolution import (
fft_conv_general_dilated,
generate_conv_dim_numbers,
)
from serket._src.utils.convert import (
canonicalize,
delayed_canonicalize_padding,
resolve_string_padding,
)
from serket._src.utils.mapping import kernel_map
from serket._src.utils.typing import CHWArray, DType, HWArray
from serket._src.utils.validate import validate_spatial_ndim
[docs]
def filter_2d(
    image: HWArray,
    weight: HWArray,
    strides: tuple[int, int] = (1, 1),
) -> HWArray:
    """Filter a 2D image with ``jax.lax.conv_general_dilated``.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        weight: 2D convolutional kernel. shape is ``(height, width)``.
        strides: stride of the convolution. Accepts tuple of two ints.

    Returns:
        The filtered 2D array, computed with ``"same"`` padding.
    """
    assert image.ndim == 2
    assert weight.ndim == 2
    # the convolution primitive works on 4D operands, so prepend two
    # singleton axes to both the image and the kernel
    lhs = jnp.expand_dims(image, (0, 1))
    rhs = jnp.expand_dims(weight, (0, 1))
    out = jax.lax.conv_general_dilated(
        lhs=lhs,
        rhs=rhs,
        window_strides=strides,
        padding="same",
        rhs_dilation=(1, 1),
        dimension_numbers=generate_conv_dim_numbers(2),
        feature_group_count=1,
    )
    # drop the two singleton axes again
    return jnp.squeeze(out, (0, 1))
[docs]
def fft_filter_2d(
    image: HWArray,
    weight: HWArray,
    strides: tuple[int, int] = (1, 1),
) -> HWArray:
    """Filter a 2D image with ``serket``'s ``fft_conv_general_dilated``.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        weight: 2D convolutional kernel. shape is ``(height, width)``.
        strides: stride of the convolution. Accepts tuple of two ints.

    Returns:
        The filtered 2D array, computed with ``"same"`` padding.
    """
    assert image.ndim == 2
    assert weight.ndim == 2
    # resolve the "same" padding from the 2D spatial shapes
    padding = resolve_string_padding(
        in_dim=image.shape,
        padding="same",
        kernel_size=weight.shape,
        strides=strides,
    )
    # the convolution works on 4D operands: prepend two singleton axes
    lhs = jnp.expand_dims(image, (0, 1))
    rhs = jnp.expand_dims(weight, (0, 1))
    out = fft_conv_general_dilated(
        lhs=lhs,
        rhs=rhs,
        strides=strides,
        padding=padding,
        dilation=(1, 1),
        groups=1,
    )
    return jnp.squeeze(out, (0, 1))
def calculate_average_kernel_1d(kernel_size: int, dtype: DType) -> HWArray:
    """Build a normalized 1D averaging kernel.

    Args:
        kernel_size: number of taps in the kernel. Accepts an int.
        dtype: data type of the kernel.

    Returns:
        Kernel of shape ``(kernel_size,)`` whose entries sum to one.
    """
    ones = jnp.ones((kernel_size,), dtype=dtype)
    normalized = ones / jnp.sum(ones)
    return normalized.astype(dtype)
[docs]
def avg_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Average blur applied as two separable 1D passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Average blurred array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    height_taps, width_taps = kernel_size
    row_kernel = calculate_average_kernel_1d(width_taps, kernel_dtype)[None]  # (1, kx)
    col_kernel = calculate_average_kernel_1d(height_taps, kernel_dtype)[None].T  # (ky, 1)
    # filter along the width first, then along the height
    blurred_rows = filter_2d(image, row_kernel)
    return filter_2d(blurred_rows, col_kernel)
[docs]
def fft_avg_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Average blur applied as two separable 1D FFT-based passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Average blurred array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    height_taps, width_taps = kernel_size
    row_kernel = calculate_average_kernel_1d(width_taps, kernel_dtype)[None]  # (1, kx)
    col_kernel = calculate_average_kernel_1d(height_taps, kernel_dtype)[None].T  # (ky, 1)
    # filter along the width first, then along the height
    blurred_rows = fft_filter_2d(image, row_kernel)
    return fft_filter_2d(blurred_rows, col_kernel)
def calculate_gaussian_kernel_1d(
    kernel_size: int,
    sigma: float,
    dtype: DType,
) -> HWArray:
    """Build a normalized 1D gaussian kernel.

    Args:
        kernel_size: number of taps in the kernel. Accepts an int.
        sigma: standard deviation of the gaussian.
        dtype: data type of the kernel.

    Returns:
        Gaussian kernel of shape ``(kernel_size,)`` whose entries sum to one.
    """
    offsets = jnp.arange(kernel_size, dtype=dtype) - kernel_size // 2
    if kernel_size % 2 == 0:
        # shift by half a tap so an even-sized kernel is symmetric about zero
        offsets = offsets + 0.5
    weights = jnp.exp(-(offsets**2) / (2 * sigma**2))
    weights = weights / jnp.sum(weights)
    return weights.astype(dtype)
[docs]
def gaussian_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    sigma: tuple[float, float],
    dtype: DType | None = None,
) -> HWArray:
    """Gaussian blur applied as two separable 1D passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``, i.e.
            ``(height, width)``.
        sigma: sigma of the gaussian kernel as ``(sy, sx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Gaussian blurred array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma
    gy = calculate_gaussian_kernel_1d(ky, sy, dtype)[None]  # (1, ky)
    gx = calculate_gaussian_kernel_1d(kx, sx, dtype)[None]  # (1, kx)
    # The x kernel must stay a row (filters along width) and the y kernel must
    # become a column (filters along height). Passing them the other way round
    # swaps the axes whenever kernel_size or sigma is not symmetric.
    return filter_2d(filter_2d(image, gx), gy.T)
[docs]
def fft_gaussian_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    sigma: tuple[float, float],
    dtype: DType | None = None,
) -> HWArray:
    """Gaussian blur applied as two separable 1D FFT-based passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``, i.e.
            ``(height, width)``.
        sigma: sigma of the gaussian kernel as ``(sy, sx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Gaussian blurred array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma
    gy = calculate_gaussian_kernel_1d(ky, sy, dtype)[None]  # (1, ky)
    gx = calculate_gaussian_kernel_1d(kx, sx, dtype)[None]  # (1, kx)
    # The x kernel must stay a row (filters along width) and the y kernel must
    # become a column (filters along height). Passing them the other way round
    # swaps the axes whenever kernel_size or sigma is not symmetric.
    return fft_filter_2d(fft_filter_2d(image, gx), gy.T)
[docs]
def unsharp_mask_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    sigma: tuple[float, float],
    dtype: DType | None = None,
) -> HWArray:
    """Unsharp mask: add back the difference between the image and its blur.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``, i.e.
            ``(height, width)``.
        sigma: sigma of the gaussian kernel as ``(sy, sx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Unsharp masked array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma
    gy = calculate_gaussian_kernel_1d(ky, sy, dtype)[None]  # (1, ky)
    gx = calculate_gaussian_kernel_1d(kx, sx, dtype)[None]  # (1, kx)
    # x kernel as a row (along width), y kernel as a column (along height);
    # the reverse assignment swaps the axes for non-symmetric sizes/sigmas.
    blur = filter_2d(filter_2d(image, gx), gy.T)
    # image + detail, where detail = image - blur
    return image + (image - blur)
[docs]
def fft_unsharp_mask_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    sigma: tuple[float, float],
    dtype: DType | None = None,
) -> HWArray:
    """Unsharp mask using FFT-based filtering for the gaussian blur.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``, i.e.
            ``(height, width)``.
        sigma: sigma of the gaussian kernel as ``(sy, sx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Unsharp masked array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma
    gy = calculate_gaussian_kernel_1d(ky, sy, dtype)[None]  # (1, ky)
    gx = calculate_gaussian_kernel_1d(kx, sx, dtype)[None]  # (1, kx)
    # x kernel as a row (along width), y kernel as a column (along height);
    # the reverse assignment swaps the axes for non-symmetric sizes/sigmas.
    blur = fft_filter_2d(fft_filter_2d(image, gx), gy.T)
    # image + detail, where detail = image - blur
    return image + (image - blur)
def calculate_box_kernel_1d(kernel_size: int, dtype: DType) -> HWArray:
    """Build a 1D box kernel with every tap equal to ``1 / kernel_size``.

    Args:
        kernel_size: number of taps in the kernel. Accepts an int.
        dtype: data type the ones are cast to before the division.

    Returns:
        Box kernel. shape is ``(kernel_size,)``.
    """
    ones = jnp.ones((kernel_size)).astype(dtype)
    return ones / kernel_size
[docs]
def box_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Box blur applied as two separable 1D passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Box blurred array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    height_taps, width_taps = kernel_size
    row_kernel = calculate_box_kernel_1d(width_taps, kernel_dtype)[None]  # (1, kx)
    col_kernel = calculate_box_kernel_1d(height_taps, kernel_dtype)[None].T  # (ky, 1)
    # filter along the width first, then along the height
    blurred_rows = filter_2d(image, row_kernel)
    return filter_2d(blurred_rows, col_kernel)
[docs]
def fft_box_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Box blur applied as two separable 1D FFT-based passes.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel as ``(ky, kx)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Box blurred array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    height_taps, width_taps = kernel_size
    row_kernel = calculate_box_kernel_1d(width_taps, kernel_dtype)[None]  # (1, kx)
    col_kernel = calculate_box_kernel_1d(height_taps, kernel_dtype)[None].T  # (ky, 1)
    # filter along the width first, then along the height
    blurred_rows = fft_filter_2d(image, row_kernel)
    return fft_filter_2d(blurred_rows, col_kernel)
def calculate_laplacian_kernel_2d(
    kernel_size: tuple[int, int], dtype: DType
) -> HWArray:
    """Build a 2D laplacian kernel whose entries sum to zero.

    Every entry is one except the center, which is set to
    ``1 - ky * kx`` so the whole kernel sums to zero.

    Args:
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        dtype: data type of the kernel.

    Returns:
        Laplacian kernel. shape is ``(kernel_size[0], kernel_size[1])``.
    """
    rows, cols = kernel_size
    ones = jnp.ones((rows, cols))
    center_value = 1 - jnp.sum(ones)
    return ones.at[rows // 2, cols // 2].set(center_value).astype(dtype)
[docs]
def laplacian_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Apply a laplacian filter to a 2D image.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Laplacian array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    return filter_2d(image, calculate_laplacian_kernel_2d(kernel_size, kernel_dtype))
[docs]
def fft_laplacian_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    dtype: DType | None = None,
) -> HWArray:
    """Apply a laplacian filter to a 2D image using FFT-based filtering.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Laplacian array. shape is ``(height, width)``.
    """
    kernel_dtype = dtype or image.dtype
    return fft_filter_2d(image, calculate_laplacian_kernel_2d(kernel_size, kernel_dtype))
def calculate_motion_kernel_2d(
    kernel_size: int,
    angle: float,
    direction,
    dtype: DType,
) -> HWArray:
    """Return a 2D motion blur filter.

    Args:
        kernel_size: motion kernel width and height. It should be odd and positive.
        angle: angle of the motion blur in degrees (anti-clockwise rotation).
        direction: direction of the motion blur. Clipped to ``[-1, 1]``.
        dtype: data type of the returned kernel.

    Returns:
        The motion blur kernel of shape ``(kernel_size, kernel_size)``,
        normalized to sum to one.
    """
    # clip to [-1, 1] and rescale to [0, 1]
    start = (jnp.clip(direction, -1.0, 1.0) + 1.0) / 2.0
    columns = jnp.arange(kernel_size)
    # linear ramp along the middle row: `start` at column 0,
    # `1 - start` at the last column
    ramp = start + ((1 - 2 * start) / (kernel_size - 1)) * columns
    line = jnp.zeros((kernel_size, kernel_size)).at[kernel_size // 2, columns].set(ramp)
    rotated = rotate_2d(line, angle)
    normalized = rotated / jnp.sum(rotated)
    return normalized.astype(dtype)
[docs]
def motion_blur_2d(
    image: HWArray,
    kernel_size: int,
    angle: float,
    direction: int | float,
    dtype: DType | None = None,
) -> HWArray:
    """Motion blur.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: motion kernel width and height as a single int.
            It should be odd and positive.
        angle: angle of the motion blur in degrees (anti-clockwise rotation).
        direction: direction of the motion blur. The value is clipped to
            ``[-1, 1]`` when the kernel is built.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Motion blurred array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    # kernel_size is forwarded as-is; the kernel builder does integer
    # arithmetic on it, so it must be an int rather than a tuple
    kernel = calculate_motion_kernel_2d(kernel_size, angle, direction, dtype)
    return filter_2d(image, kernel)
[docs]
def fft_motion_blur_2d(
    image: HWArray,
    kernel_size: int,
    angle: float,
    direction: int | float,
    dtype: DType | None = None,
) -> HWArray:
    """Motion blur using FFT.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: motion kernel width and height as a single int.
            It should be odd and positive.
        angle: angle of the motion blur in degrees (anti-clockwise rotation).
        direction: direction of the motion blur. The value is clipped to
            ``[-1, 1]`` when the kernel is built.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        Motion blurred array. shape is ``(height, width)``.
    """
    dtype = dtype or image.dtype
    # kernel_size is forwarded as-is; the kernel builder does integer
    # arithmetic on it, so it must be an int rather than a tuple
    kernel = calculate_motion_kernel_2d(kernel_size, angle, direction, dtype)
    return fft_filter_2d(image, kernel)
def calculate_sobel_kernel_2d(dtype: DType):
    """Return the four 1D sobel factors stacked into a ``(4, 3)`` array.

    Rows 0-1 are one separable pair and rows 2-3 the other; each pair is one
    smoothing factor ``[1, 2, 1]`` and one difference factor ``[1, 0, -1]``.
    """
    smooth = [1, 2, 1]
    diff = [1, 0, -1]
    # row order: x0, x1, y0, y1
    return jnp.array([smooth, diff, diff, smooth]).astype(dtype)
[docs]
def sobel_2d(image: HWArray, dtype: DType | None = None) -> HWArray:
    """Sobel filter returning the gradient magnitude.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        ``sqrt(gx**2 + gy**2)`` of the two separable sobel responses.
    """
    factors = calculate_sobel_kernel_2d(dtype or image.dtype)
    # four (1, 3) row kernels; the second of each pair is transposed to (3, 1)
    x_row, x_col, y_row, y_col = jnp.split(factors, 4)
    response_x = filter_2d(filter_2d(image, x_row), x_col.T)
    response_y = filter_2d(filter_2d(image, y_row), y_col.T)
    return jnp.sqrt(response_x**2 + response_y**2)
[docs]
def fft_sobel_2d(image: HWArray, dtype: DType | None = None) -> HWArray:
    """Sobel filter returning the gradient magnitude, using FFT filtering.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.

    Returns:
        ``sqrt(gx**2 + gy**2)`` of the two separable sobel responses.
    """
    factors = calculate_sobel_kernel_2d(dtype or image.dtype)
    # four (1, 3) row kernels; the second of each pair is transposed to (3, 1)
    x_row, x_col, y_row, y_col = jnp.split(factors, 4)
    response_x = fft_filter_2d(fft_filter_2d(image, x_row), x_col.T)
    response_y = fft_filter_2d(fft_filter_2d(image, y_row), y_col.T)
    return jnp.sqrt(response_x**2 + response_y**2)
[docs]
def bilateral_blur_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    sigma_space: tuple[float, float],
    sigma_color: float,
    dtype: DType | None = None,
):
    """Bilateral blur.

    Each output pixel is a weighted mean of its window, where the weight is
    the product of a spatial gaussian and a gaussian of the squared
    difference to the window's center value.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        sigma_space: sigma of the spatial gaussian kernel as ``(sy, sx)``.
        sigma_color: sigma of the color (intensity) gaussian kernel.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.
    """
    kernel_dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma_space
    center = (ky // 2, kx // 2)
    col = calculate_gaussian_kernel_1d(ky, sy, dtype=kernel_dtype)[None].T  # (ky, 1)
    row = calculate_gaussian_kernel_1d(kx, sx, dtype=kernel_dtype)[None]  # (1, kx)
    space_kernel = col @ row  # (ky, kx)

    same_padding = delayed_canonicalize_padding(
        image.shape,
        padding="same",
        kernel_size=kernel_size,
        strides=(1, 1),
    )

    def blur_window(window):
        # weight every pixel by how close its value is to the center value
        color_distance = (window - window[center]) ** 2
        color_kernel = jnp.exp(-0.5 / sigma_color**2 * color_distance)
        weights = color_kernel * space_kernel
        return jnp.sum(window * weights) / jnp.sum(weights)

    # apply `blur_window` to every kernel-sized window of the padded image
    mapped = kernel_map(
        blur_window,
        shape=image.shape,
        kernel_size=kernel_size,
        strides=(1, 1),
        padding=same_padding,
        padding_mode=0,
    )
    return mapped(image)
[docs]
def joint_bilateral_blur_2d(
    image: HWArray,
    guidance: HWArray,
    kernel_size: tuple[int, int],
    sigma_space: tuple[float, float],
    sigma_color: float,
    dtype: DType | None = None,
):
    """Joint bilateral blur.

    Like the bilateral blur, but the color weights are computed from
    ``guidance`` while the averaged values come from ``image``.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        guidance: 2D guidance array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        sigma_space: sigma of the spatial gaussian kernel as ``(sy, sx)``.
        sigma_color: sigma of the color (intensity) gaussian kernel.
        dtype: data type of the kernel. Defaults to ``None`` to use the same
            data type as ``image``.
    """
    kernel_dtype = dtype or image.dtype
    (ky, kx), (sy, sx) = kernel_size, sigma_space
    center = (ky // 2, kx // 2)
    col = calculate_gaussian_kernel_1d(ky, sy, dtype=kernel_dtype)[None].T  # (ky, 1)
    row = calculate_gaussian_kernel_1d(kx, sx, dtype=kernel_dtype)[None]  # (1, kx)
    space_kernel = col @ row  # (ky, kx)

    same_padding = delayed_canonicalize_padding(
        image.shape,
        padding="same",
        kernel_size=kernel_size,
        strides=(1, 1),
    )

    def blur_window(stacked_window):
        # index 0 holds the image window, index 1 the guidance window
        values, guide = stacked_window
        color_distance = (guide - guide[center]) ** 2
        color_kernel = jnp.exp(-0.5 / sigma_color**2 * color_distance)
        weights = color_kernel * space_kernel
        return jnp.sum(values * weights) / jnp.sum(weights)

    # image and guidance are stacked on a new leading axis that the window
    # spans entirely (size 2, no padding), so that axis has length 1 in the
    # result and is squeezed away below
    mapped = kernel_map(
        blur_window,
        shape=(2, *image.shape),
        kernel_size=(2, *kernel_size),
        strides=(1, 1, 1),
        padding=((0, 0), *same_padding),
        padding_mode=0,
    )
    stacked = jnp.stack([image, guidance], axis=0)
    return jnp.squeeze(mapped(stacked), 0)
def calculate_pascal_kernel_1d(kernel_size: int):
    """Return row ``kernel_size - 1`` of Pascal's triangle as a float array.

    The row is built by convolving ``[1]`` with ``[1, 1]`` a total of
    ``kernel_size - 1`` times, so the result has ``kernel_size`` entries
    (unnormalized binomial coefficients).
    """
    pair = jnp.array([1.0, 1.0], dtype=float)
    seed = jnp.array([1.0], dtype=float)
    return ft.reduce(
        lambda row, _: jnp.convolve(row, pair, mode="full"),
        range(1, kernel_size),
        seed,
    )
[docs]
def blur_pool_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
) -> HWArray:
    """Blur pooling, see https://arxiv.org/abs/1904.11486

    Filters the image with a normalized 2D kernel built as the outer product
    of two Pascal-triangle rows, using the given strides.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        strides: stride of the convolution. Accepts tuple of two ints.
    """
    height_taps, width_taps = kernel_size
    binomial = jnp.outer(
        calculate_pascal_kernel_1d(height_taps),
        calculate_pascal_kernel_1d(width_taps),
    )
    binomial = binomial / jnp.sum(binomial)  # normalize
    return filter_2d(image, binomial, strides)
[docs]
def fft_blur_pool_2d(
    image: HWArray,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
) -> HWArray:
    """Blur pooling using FFT filtering, see https://arxiv.org/abs/1904.11486

    Filters the image with a normalized 2D kernel built as the outer product
    of two Pascal-triangle rows, using the given strides.

    Args:
        image: 2D input array. shape is ``(height, width)``.
        kernel_size: size of the convolving kernel. Accepts tuple of two ints.
        strides: stride of the convolution. Accepts tuple of two ints.
    """
    height_taps, width_taps = kernel_size
    binomial = jnp.outer(
        calculate_pascal_kernel_1d(height_taps),
        calculate_pascal_kernel_1d(width_taps),
    )
    binomial = binomial / jnp.sum(binomial)  # normalize
    return fft_filter_2d(image, binomial, strides)
class BaseAvgBlur2D(TreeClass):
    """Shared implementation of the average blur layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(channel, kernel_size)`` for every slice along axis 0.
    """

    def __init__(self, kernel_size: int | tuple[int, int]):
        # canonicalized with ndim=2
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # map over axis 0 of the image; the kernel size is shared
        per_channel = jax.vmap(self.filter_op, in_axes=(0, None))
        return per_channel(image, self.kernel_size)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class AvgBlur2D(BaseAvgBlur2D):
    """Average blur layer for channel-first 2D images.

    .. image:: ../_static/avgblur2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.AvgBlur2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.44444448 0.6666667  0.6666667  0.6666667  0.44444448]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.44444448 0.6666667  0.6666667  0.6666667  0.44444448]]]
    """

    filter_op = staticmethod(avg_blur_2d)
[docs]
class FFTAvgBlur2D(BaseAvgBlur2D):
    """Average blur layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/avgblur2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.FFTAvgBlur2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.44444448 0.6666667  0.6666667  0.6666667  0.44444448]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.44444448 0.6666667  0.6666667  0.6666667  0.44444448]]]
    """

    filter_op = staticmethod(fft_avg_blur_2d)
class BaseGaussianBlur2D(TreeClass):
    """Shared implementation of the gaussian-kernel based layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(channel, kernel_size, sigma)`` for every slice along axis 0.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        *,
        sigma: float | tuple[float, float] = 1.0,
    ):
        # both values are canonicalized with ndim=2
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")
        self.sigma = canonicalize(sigma, ndim=2, name="sigma")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # no gradient flows back into sigma
        frozen_sigma = jax.lax.stop_gradient(self.sigma)
        # map over axis 0 of the image; kernel size and sigma are shared
        per_channel = jax.vmap(self.filter_op, in_axes=(0, None, None))
        return per_channel(image, self.kernel_size, frozen_sigma)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class GaussianBlur2D(BaseGaussianBlur2D):
    """Gaussian blur layer for channel-first 2D images.

    .. image:: ../_static/gaussianblur2d.png

    Args:
        kernel_size: kernel size. Accepts int or tuple of two ints.
        sigma: sigma. Defaults to 1. Accepts float or tuple of two floats.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.GaussianBlur2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
    """

    filter_op = staticmethod(gaussian_blur_2d)
[docs]
class FFTGaussianBlur2D(BaseGaussianBlur2D):
    """Gaussian blur layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/gaussianblur2d.png

    Args:
        kernel_size: kernel size. Accepts int or tuple of two ints.
        sigma: sigma. Defaults to 1. Accepts float or tuple of two floats.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.FFTGaussianBlur2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.7259314 1.        1.        1.        0.7259314]
          [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
    """

    filter_op = staticmethod(fft_gaussian_blur_2d)
[docs]
class UnsharpMask2D(BaseGaussianBlur2D):
    """Unsharp mask layer for channel-first 2D images.

    .. image:: ../_static/unsharpmask2d.png

    Args:
        kernel_size: kernel size. Accepts int or tuple of two ints.
        sigma: sigma. Defaults to 1. Accepts float or tuple of two floats.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.UnsharpMask2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]]
    """

    filter_op = staticmethod(unsharp_mask_2d)
[docs]
class FFTUnsharpMask2D(BaseGaussianBlur2D):
    """Unsharp mask layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/unsharpmask2d.png

    Args:
        kernel_size: kernel size. Accepts int or tuple of two ints.
        sigma: sigma. Defaults to 1. Accepts float or tuple of two floats.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.FFTUnsharpMask2D(kernel_size=3)
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.2740686 1.        1.        1.        1.2740686]
          [1.4730237 1.2740686 1.2740686 1.2740686 1.4730237]]]
    """

    filter_op = staticmethod(fft_unsharp_mask_2d)
class BoxBlur2DBase(TreeClass):
    """Shared implementation of the box blur layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(channel, kernel_size)`` for every slice along axis 0.
    """

    def __init__(self, kernel_size: int | tuple[int, int]):
        # canonicalized with ndim=2
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # map over axis 0 of the image; the kernel size is shared
        per_channel = jax.vmap(self.filter_op, in_axes=(0, None))
        return per_channel(image, self.kernel_size)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class BoxBlur2D(BoxBlur2DBase):
    """Box blur layer for channel-first 2D images.

    .. image:: ../_static/boxblur2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.BoxBlur2D((3, 5))
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.40000004 0.53333336 0.6666667  0.53333336 0.40000004]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.40000004 0.53333336 0.6666667  0.53333336 0.40000004]]]
    """

    filter_op = staticmethod(box_blur_2d)
[docs]
class FFTBoxBlur2D(BoxBlur2DBase):
    """Box blur layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/boxblur2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.FFTBoxBlur2D((3, 5))
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[0.40000004 0.53333336 0.6666667  0.53333336 0.40000004]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.6        0.8        1.         0.8        0.6       ]
          [0.40000004 0.53333336 0.6666667  0.53333336 0.40000004]]]
    """

    filter_op = staticmethod(fft_box_blur_2d)
class Laplacian2DBase(TreeClass):
    """Shared implementation of the laplacian layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(channel, kernel_size)`` for every slice along axis 0.
    """

    def __init__(self, kernel_size: int | tuple[int, int]):
        # canonicalized with ndim=2
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # map over axis 0 of the image; the kernel size is shared
        per_channel = jax.vmap(self.filter_op, in_axes=(0, None))
        return per_channel(image, self.kernel_size)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class Laplacian2D(Laplacian2DBase):
    """Laplacian filter layer for channel-first 2D images.

    .. image:: ../_static/laplacian2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.Laplacian2D(kernel_size=(3, 5))
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[-9. -7. -5. -7. -9.]
          [-6. -3.  0. -3. -6.]
          [-6. -3.  0. -3. -6.]
          [-6. -3.  0. -3. -6.]
          [-9. -7. -5. -7. -9.]]]

    Note:
        The laplacian considers all the neighbors of a pixel.
    """

    filter_op = staticmethod(laplacian_2d)
[docs]
class FFTLaplacian2D(Laplacian2DBase):
    """Laplacian filter layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/laplacian2d.png

    Args:
        kernel_size: size of the convolving kernel. Accepts int or tuple of
            two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> layer = sk.image.FFTLaplacian2D(kernel_size=(3, 5))
        >>> print(layer(jnp.ones((1, 5, 5))))  # doctest: +SKIP
        [[[-9. -7. -5. -7. -9.]
          [-6. -3.  0. -3. -6.]
          [-6. -3.  0. -3. -6.]
          [-6. -3.  0. -3. -6.]
          [-9. -7. -5. -7. -9.]]]

    Note:
        The laplacian considers all the neighbors of a pixel.
    """

    filter_op = staticmethod(fft_laplacian_2d)
class MotionBlur2DBase(TreeClass):
    """Shared implementation of the motion blur layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(channel, kernel_size, angle, direction)`` for every slice
    along axis 0.
    """

    def __init__(
        self,
        kernel_size: int,
        *,
        angle: float = 0.0,
        direction: float = 0.0,
    ):
        # stored as given; kernel_size is a single int here, not a pair
        self.kernel_size = kernel_size
        self.angle = angle
        self.direction = direction

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # no gradient flows back into angle or direction
        frozen = jax.lax.stop_gradient((self.angle, self.direction))
        angle, direction = frozen
        # map over axis 0 of the image; the remaining arguments are shared
        per_channel = jax.vmap(self.filter_op, in_axes=(0, None, None, None))
        return per_channel(image, self.kernel_size, angle, direction)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class MotionBlur2D(MotionBlur2DBase):
    """Motion blur layer for channel-first 2D images.

    .. image:: ../_static/motionblur2d.png

    Args:
        kernel_size: motion kernel width and height. It should be odd and positive.
        angle: angle of the motion blur in degrees (anti-clockwise rotation).
        direction: direction of the motion blur.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + 0.0
        >>> print(sk.image.MotionBlur2D(3, angle=30, direction=0.5)(x))  # doctest: +SKIP
        [[[ 0.7827108  2.4696379  3.3715053  3.8119273]
          [ 2.8356633  6.3387947  7.3387947  7.1810846]
          [ 5.117592  10.338796  11.338796  10.550241 ]
          [ 6.472714  10.020969  10.770187   9.100007 ]]]
    """

    filter_op = staticmethod(motion_blur_2d)
[docs]
class FFTMotionBlur2D(MotionBlur2DBase):
    """Motion blur layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/motionblur2d.png

    Args:
        kernel_size: motion kernel width and height. It should be odd and positive.
        angle: angle of the motion blur in degrees (anti-clockwise rotation).
        direction: direction of the motion blur.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + 0.0
        >>> print(sk.image.FFTMotionBlur2D(3, angle=30, direction=0.5)(x))  # doctest: +SKIP
        [[[ 0.7827108  2.4696379  3.3715053  3.8119273]
          [ 2.8356633  6.3387947  7.3387947  7.1810846]
          [ 5.117592  10.338796  11.338796  10.550241 ]
          [ 6.472714  10.020969  10.770187   9.100007 ]]]
    """

    filter_op = staticmethod(fft_motion_blur_2d)
class Sobel2DBase(TreeClass):
    """Shared implementation of the sobel layers.

    Subclasses supply ``filter_op``, which is called with one slice of the
    input along axis 0 at a time.
    """

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        # map the filter over axis 0 of the image
        per_channel = jax.vmap(self.filter_op)
        return per_channel(image)

    spatial_ndim: int = 2
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class Sobel2D(Sobel2DBase):
    """Sobel filter layer for channel-first 2D images.

    .. image:: ../_static/sobel2d.png

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1,26).reshape(1, 5,5).astype(jnp.float32)
        >>> layer = sk.image.Sobel2D()
        >>> layer(x)  # doctest: +SKIP
        [[[21.954498, 28.635643, 32.55764 , 36.496574, 33.61547 ],
          [41.036568, 40.792156, 40.792156, 40.792156, 46.8615  ],
          [56.603886, 40.792156, 40.792156, 40.792156, 63.529522],
          [74.323616, 40.792156, 40.792156, 40.792156, 81.706795],
          [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]]
    """

    filter_op = staticmethod(sobel_2d)
[docs]
class FFTSobel2D(Sobel2DBase):
    """Sobel filter layer for channel-first 2D images, computed with FFT.

    .. image:: ../_static/sobel2d.png

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1,26).reshape(1, 5,5).astype(jnp.float32)
        >>> layer = sk.image.FFTSobel2D()
        >>> layer(x)  # doctest: +SKIP
        [[[21.954498, 28.635643, 32.55764 , 36.496574, 33.61547 ],
          [41.036568, 40.792156, 40.792156, 40.792156, 46.8615  ],
          [56.603886, 40.792156, 40.792156, 40.792156, 63.529522],
          [74.323616, 40.792156, 40.792156, 40.792156, 81.706795],
          [78.24321 , 68.26419 , 72.249565, 76.23647 , 89.27486 ]]]
    """

    filter_op = staticmethod(fft_sobel_2d)
class ElasticTransform2DBase(TreeClass):
    """Shared implementation of the elastic transform layers.

    Subclasses supply ``filter_op``, which is called as
    ``filter_op(key, channel, kernel_size, sigma, alpha)`` for every slice of
    the image along axis 0; the same ``key`` is passed to every slice.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        sigma: float | tuple[float, float],
        alpha: float | tuple[float, float],
    ):
        # all three values are canonicalized with ndim=2
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")
        self.alpha = canonicalize(alpha, ndim=2, name="alpha")
        self.sigma = canonicalize(sigma, ndim=2, name="sigma")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, *, key: jax.Array) -> CHWArray:
        # only the image (second positional argument) is mapped over axis 0
        per_channel = jax.vmap(self.filter_op, in_axes=(None, 0, None, None, None))
        return per_channel(key, image, self.kernel_size, self.sigma, self.alpha)

    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
    spatial_ndim: int = 2
[docs]
class BilateralBlur2D(TreeClass):
    """Apply bilateral blur to a channel-first image.

    .. image:: ../_static/bilateralblur2d.png

    Args:
        kernel_size: kernel size. accepts int or tuple of two ints.
        sigma_space: sigma in the coordinate space. accepts float or tuple of two floats.
        sigma_color: sigma in the color space. accepts float.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.ones([1, 5, 5])
        >>> layer = sk.image.BilateralBlur2D((3, 5), sigma_space=(1.2, 1.3), sigma_color=1.5)
        >>> print(layer(x))  # doctest: +SKIP
        [[[0.5231399  0.6869784  0.75100434 0.6869784  0.5231399 ]
          [0.70914114 0.9193193  1.         0.9193192  0.70914114]
          [0.70914114 0.9193193  1.         0.9193192  0.70914114]
          [0.70914114 0.9193193  1.         0.9193192  0.70914114]
          [0.5231399  0.6869784  0.75100434 0.6869784  0.5231399 ]]]
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        *,
        sigma_space: float | tuple[float, float],
        sigma_color: float,
    ):
        # kernel_size and sigma_space are canonicalized with ndim=2;
        # sigma_color is a single scalar and is stored as given
        self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")
        self.sigma_space = canonicalize(sigma_space, ndim=2, name="sigma_space")
        self.sigma_color = sigma_color

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        """Blur every slice of ``image`` along axis 0 with the stored settings."""
        # map over axis 0 of the image; the remaining arguments are shared
        in_axes = (0, None, None, None)
        args = (self.kernel_size, self.sigma_space, self.sigma_color)
        return jax.vmap(bilateral_blur_2d, in_axes=in_axes)(image, *args)

    spatial_ndim: int = 2
[docs]
class JointBilateralBlur2D(TreeClass):
    """Apply joint bilateral blur to a channel-first image.

    .. image:: ../_static/jointbilateralblur2d.png

    Args:
        kernel_size: kernel size. accepts int or tuple of two ints.
        sigma_space: sigma in the coordinate space. accepts float or tuple of
            two floats.
        sigma_color: sigma in the color space. accepts float. Stored as given.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.ones([1, 5, 5])
        >>> guide = jnp.ones([1, 5, 5])
        >>> layer = sk.image.JointBilateralBlur2D((3, 5), sigma_space=(1.2, 1.3), sigma_color=1.5)
        >>> print(layer(x, guide))  # doctest: +SKIP
        [[[0.5231399 0.6869784 0.75100434 0.6869784 0.5231399 ]
          [0.70914114 0.9193193 1. 0.9193192 0.70914114]
          [0.70914114 0.9193193 1. 0.9193192 0.70914114]
          [0.70914114 0.9193193 1. 0.9193192 0.70914114]
          [0.5231399 0.6869784 0.75100434 0.6869784 0.5231399 ]]]
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        *,
        sigma_space: float | tuple[float, float],
        sigma_color: float,
    ):
        self.kernel_size = canonicalize(
            kernel_size,
            ndim=2,
            name="kernel_size",
        )
        self.sigma_space = canonicalize(
            sigma_space,
            ndim=2,
            name="sigma_space",
        )
        self.sigma_color = sigma_color

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray, guide: CHWArray) -> CHWArray:
        """Apply joint bilateral blur to a channel-first image.

        ``image`` and ``guide`` are mapped together along their leading axes;
        the layer settings are shared by all slices.

        Args:
            image: input image.
            guide: guide image used for computing the gaussian for color space.
        """
        # only ``image`` (argnum=0) is handed to the ``validate_spatial_ndim``
        # decorator; ``guide`` is forwarded to the filter as-is.
        blur_each = jax.vmap(
            joint_bilateral_blur_2d,
            in_axes=(0, 0, None, None, None),
        )
        return blur_each(
            image,
            guide,
            self.kernel_size,
            self.sigma_space,
            self.sigma_color,
        )

    spatial_ndim: int = 2
class BlurPool2DBase(TreeClass):
    """Shared constructor and call logic for the blur-pool layers.

    Concrete subclasses set ``filter_op``, which is invoked as
    ``filter_op(image, kernel_size, strides)`` and mapped over the leading
    axis of ``image``.

    Args:
        kernel_size: int or tuple of two ints, passed through
            ``canonicalize(..., ndim=2)``.
        strides: int or tuple of two ints, passed through
            ``canonicalize(..., ndim=2)``.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        strides: int | tuple[int, int],
    ):
        to_2d = ft.partial(canonicalize, ndim=2)
        self.kernel_size = to_2d(kernel_size, name="kernel_size")
        self.strides = to_2d(strides, name="strides")

    @ft.partial(validate_spatial_ndim, argnum=0)
    def __call__(self, image: CHWArray) -> CHWArray:
        """Apply ``filter_op`` to every slice along the leading axis of ``image``.

        Args:
            image: input image; kernel size and strides are shared by all slices.
        """
        pool_each = jax.vmap(self.filter_op, in_axes=(0, None, None))
        return pool_each(image, self.kernel_size, self.strides)

    spatial_ndim: int = 2
    # abstract placeholder, replaced by a ``staticmethod``-wrapped function in
    # concrete subclasses.
    filter_op = staticmethod(abc.abstractmethod(lambda _: ...))
[docs]
class BlurPool2D(BlurPool2DBase):
    """Blur and downsample a channel-first image.

    .. image:: ../_static/blurpool2d.png

    Args:
        kernel_size: kernel size. accepts int or tuple of two ints.
        strides: strides. accepts int or tuple of two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> layer = sk.image.BlurPool2D(kernel_size=3, strides=2)
        >>> print(layer(x))  # doctest: +SKIP
        [[[ 1.6875 3.5 3.5625]
          [ 8.5 13. 11. ]
          [11.0625 16. 12.9375]]]
    """

    # the inherited ``__call__`` maps this function over the leading image axis
    filter_op = staticmethod(blur_pool_2d)
[docs]
class FFTBlurPool2D(BlurPool2DBase):
    """Blur and downsample a channel-first image using FFT.

    .. image:: ../_static/blurpool2d.png

    Args:
        kernel_size: kernel size. accepts int or tuple of two ints.
        strides: stride. accepts int or tuple of two ints.

    Example:
        >>> import serket as sk
        >>> import jax.numpy as jnp
        >>> x = jnp.arange(1, 26).reshape(1, 5, 5)
        >>> layer = sk.image.FFTBlurPool2D(kernel_size=3, strides=2)
        >>> print(layer(x))  # doctest: +SKIP
        [[[ 1.6875 3.5 3.5625]
          [ 8.5 13. 11. ]
          [11.0625 16. 12.9375]]]
    """

    # the inherited ``__call__`` maps this function over the leading image axis
    filter_op = staticmethod(fft_blur_pool_2d)