# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import inspect
from collections.abc import Callable as ABCCallable
from typing import Callable, Union, get_args
import jax
import jax.numpy as jnp
from jax import lax
from serket import TreeClass, autoinit, field
from serket._src.utils.typing import ActivationLiteral
from serket._src.utils.validate import IsInstance, Range, ScalarLike
@autoinit
class CeLU(TreeClass):
    """CELU activation layer built on ``jax.nn.celu``.

    Attributes:
        alpha: value forwarded as ``alpha`` to ``jax.nn.celu`` (default
            ``1.0``). ``ScalarLike()`` is registered in ``on_setattr`` and
            ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    alpha: float = field(
        default=1.0,
        on_setattr=[ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.celu`` of ``input`` using the stored ``alpha``."""
        alpha = self.alpha
        return jax.nn.celu(input, alpha=alpha)
@autoinit
class ELU(TreeClass):
    """ELU activation layer built on ``jax.nn.elu``.

    Attributes:
        alpha: value forwarded as ``alpha`` to ``jax.nn.elu`` (default
            ``1.0``). ``ScalarLike()`` is registered in ``on_setattr`` and
            ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    alpha: float = field(
        default=1.0,
        on_setattr=[ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.elu`` of ``input`` using the stored ``alpha``."""
        alpha = self.alpha
        return jax.nn.elu(input, alpha=alpha)
@autoinit
class GELU(TreeClass):
    """GELU activation layer built on ``jax.nn.gelu``.

    Attributes:
        approximate: flag forwarded as ``approximate`` to ``jax.nn.gelu``
            (default ``False``). ``IsInstance(bool)`` is registered in
            ``on_setattr``.
    """

    approximate: bool = field(default=False, on_setattr=[IsInstance(bool)])

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.gelu`` of ``input`` with the stored flag."""
        use_approximation = self.approximate
        return jax.nn.gelu(input, approximate=use_approximation)
@autoinit
class GLU(TreeClass):
    """Parameter-free layer that applies ``jax.nn.glu``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.glu(input)``."""
        gated = jax.nn.glu(input)
        return gated
def hard_shrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
    """Keep entries beyond ``+/-alpha`` and replace the rest with ``0.0``.

    Args:
        input: values to shrink.
        alpha: threshold, ``0.5`` by default.

    Returns:
        ``input`` where ``input > alpha`` or ``input < -alpha``, and ``0.0``
        everywhere else.
    """
    # One combined mask replaces the two nested selections; the comparisons
    # themselves are unchanged.
    outside_band = jnp.logical_or(input > alpha, input < -alpha)
    return jnp.where(outside_band, input, 0.0)
@autoinit
class HardShrink(TreeClass):
    """Layer form of :func:`hard_shrink`.

    Attributes:
        alpha: threshold handed to ``hard_shrink`` (default ``0.5``).
            ``Range(0)`` and ``ScalarLike()`` are registered in
            ``on_setattr``; ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    alpha: float = field(
        default=0.5,
        on_setattr=[Range(0), ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``hard_shrink`` of ``input`` using the stored ``alpha``."""
        return hard_shrink(input, alpha=self.alpha)
class HardSigmoid(TreeClass):
    """Parameter-free layer that applies ``jax.nn.hard_sigmoid``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.hard_sigmoid(input)``."""
        activated = jax.nn.hard_sigmoid(input)
        return activated
class HardSwish(TreeClass):
    """Parameter-free layer that applies ``jax.nn.hard_swish``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.hard_swish(input)``."""
        activated = jax.nn.hard_swish(input)
        return activated
class HardTanh(TreeClass):
    """Parameter-free layer that applies ``jax.nn.hard_tanh``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.hard_tanh(input)``."""
        activated = jax.nn.hard_tanh(input)
        return activated
class LogSigmoid(TreeClass):
    """Parameter-free layer that applies ``jax.nn.log_sigmoid``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.log_sigmoid(input)``."""
        activated = jax.nn.log_sigmoid(input)
        return activated
class LogSoftmax(TreeClass):
    """Parameter-free layer that applies ``jax.nn.log_softmax``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.log_softmax(input)`` with its default arguments."""
        log_probs = jax.nn.log_softmax(input)
        return log_probs
@autoinit
class LeakyReLU(TreeClass):
    """Leaky ReLU activation layer built on ``jax.nn.leaky_relu``.

    Attributes:
        negative_slope: slope handed to ``jax.nn.leaky_relu`` (default
            ``0.01``). ``Range(0)`` and ``ScalarLike()`` are registered in
            ``on_setattr``; ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    negative_slope: float = field(
        default=0.01,
        on_setattr=[Range(0), ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.leaky_relu`` of ``input`` with the stored slope."""
        slope = self.negative_slope
        return jax.nn.leaky_relu(input, slope)
class ReLU(TreeClass):
    """Parameter-free layer that applies ``jax.nn.relu``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.relu(input)``."""
        activated = jax.nn.relu(input)
        return activated
class ReLU6(TreeClass):
    """Parameter-free layer that applies ``jax.nn.relu6``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.relu6(input)``."""
        activated = jax.nn.relu6(input)
        return activated
class SeLU(TreeClass):
    """Parameter-free layer that applies ``jax.nn.selu``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.selu(input)``."""
        activated = jax.nn.selu(input)
        return activated
class Sigmoid(TreeClass):
    """Parameter-free layer that applies ``jax.nn.sigmoid``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.sigmoid(input)``."""
        activated = jax.nn.sigmoid(input)
        return activated
class SoftPlus(TreeClass):
    """Parameter-free layer that applies ``jax.nn.softplus``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.softplus(input)``."""
        activated = jax.nn.softplus(input)
        return activated
def softsign(x: jax.typing.ArrayLike) -> jax.Array:
    """Compute ``x / (1 + |x|)`` elementwise.

    Args:
        x: input values.

    Returns:
        ``x`` divided by one plus its absolute value.
    """
    denominator = 1 + jnp.abs(x)
    return x / denominator
class SoftSign(TreeClass):
    """Parameter-free layer form of :func:`softsign`."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``softsign(input)``."""
        activated = softsign(input)
        return activated
def softshrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array:
    """Shift entries beyond ``+/-alpha`` towards zero; zero out the rest.

    Args:
        input: values to shrink.
        alpha: threshold and shift amount, ``0.5`` by default.

    Returns:
        ``input + alpha`` where ``input < -alpha``; otherwise
        ``input - alpha`` where ``input > alpha``; otherwise ``0.0``.
    """
    # The ``input < -alpha`` test is applied last so it keeps priority over
    # the ``input > alpha`` test, as in the nested form.
    upper_or_zero = jnp.where(input > alpha, input - alpha, 0.0)
    return jnp.where(input < -alpha, input + alpha, upper_or_zero)
@autoinit
class SoftShrink(TreeClass):
    """Layer form of :func:`softshrink`.

    Attributes:
        alpha: threshold handed to ``softshrink`` (default ``0.5``).
            ``Range(0)`` and ``ScalarLike()`` are registered in
            ``on_setattr``; ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    alpha: float = field(
        default=0.5,
        on_setattr=[Range(0), ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``softshrink`` of ``input`` using the stored ``alpha``."""
        return softshrink(input, alpha=self.alpha)
def squareplus(input: jax.typing.ArrayLike) -> jax.Array:
    """Compute ``0.5 * (input + sqrt(input**2 + 4))`` elementwise.

    Args:
        input: input values.

    Returns:
        The SquarePlus activation of ``input``.
    """
    radicand = input * input + 4
    return 0.5 * (input + jnp.sqrt(radicand))
class SquarePlus(TreeClass):
    """Parameter-free layer form of :func:`squareplus`."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``squareplus(input)``."""
        activated = squareplus(input)
        return activated
class Swish(TreeClass):
    """Parameter-free layer that applies ``jax.nn.swish``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.swish(input)``."""
        activated = jax.nn.swish(input)
        return activated
class Tanh(TreeClass):
    """Parameter-free layer that applies ``jax.nn.tanh``."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``jax.nn.tanh(input)``."""
        activated = jax.nn.tanh(input)
        return activated
def tanh_shrink(input: jax.typing.ArrayLike) -> jax.Array:
    """Compute ``input - tanh(input)`` elementwise.

    Args:
        input: input values.

    Returns:
        The difference between ``input`` and its hyperbolic tangent.
    """
    squashed = jnp.tanh(input)
    return input - squashed
class TanhShrink(TreeClass):
    """Parameter-free layer form of :func:`tanh_shrink`."""

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``tanh_shrink(input)``."""
        activated = tanh_shrink(input)
        return activated
def thresholded_relu(input: jax.typing.ArrayLike, theta: float = 1.0) -> jax.Array:
    """Pass through entries strictly greater than ``theta``; zero the rest.

    Args:
        input: input values.
        theta: threshold, ``1.0`` by default.

    Returns:
        ``input`` where ``input > theta`` and ``0`` elsewhere.

    Reference:
        https://arxiv.org/pdf/1911.09737.pdf.
    """
    above_threshold = input > theta
    return jnp.where(above_threshold, input, 0)
@autoinit
class ThresholdedReLU(TreeClass):
    """Layer form of :func:`thresholded_relu`.

    Attributes:
        theta: threshold handed to ``thresholded_relu`` (default ``1.0``).
            ``Range(0)`` and ``ScalarLike()`` are registered in
            ``on_setattr``; ``lax.stop_gradient_p.bind`` in ``on_getattr``.
    """

    theta: float = field(
        default=1.0,
        on_setattr=[Range(0), ScalarLike()],
        on_getattr=[lax.stop_gradient_p.bind],
    )

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``thresholded_relu`` of ``input`` with the stored ``theta``."""
        return thresholded_relu(input, theta=self.theta)
def mish(input: jax.typing.ArrayLike) -> jax.Array:
    """Compute ``input * tanh(softplus(input))`` elementwise.

    Reference:
        https://arxiv.org/pdf/1908.08681.pdf.
    """
    gate = jax.nn.tanh(jax.nn.softplus(input))
    return input * gate
class Mish(TreeClass):
    """Parameter-free layer form of :func:`mish`.

    Reference:
        https://arxiv.org/pdf/1908.08681.pdf.
    """

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``mish(input)``."""
        activated = mish(input)
        return activated
def prelu(input: jax.typing.ArrayLike, a: float = 0.25) -> jax.Array:
    """Scale negative entries by ``a`` and leave the others unchanged.

    Args:
        input: input values.
        a: multiplier applied where ``input < 0``; ``0.25`` by default.

    Returns:
        ``input`` where ``input >= 0`` and ``input * a`` elsewhere.
    """
    scaled_negative = input * a
    return jnp.where(input >= 0, input, scaled_negative)
@autoinit
class PReLU(TreeClass):
    """Layer form of :func:`prelu`.

    Attributes:
        a: multiplier handed to ``prelu`` (default ``0.25``). ``Range(0)``
            and ``ScalarLike()`` are registered in ``on_setattr``; no
            ``on_getattr`` hook is set for this field.
    """

    a: float = field(default=0.25, on_setattr=[Range(0), ScalarLike()])

    def __call__(self, input: jax.Array) -> jax.Array:
        """Return ``prelu`` of ``input`` using the stored ``a``."""
        return prelu(input, a=self.a)
# Lookup tables used to build layers from configuration text.
#
# ``acts`` is paired positionally with the names returned by
# ``get_args(ActivationLiteral)``, so its order must match that literal.
acts = [
    jax.nn.celu,
    jax.nn.elu,
    jax.nn.gelu,
    jax.nn.glu,
    hard_shrink,
    jax.nn.hard_sigmoid,
    jax.nn.hard_swish,
    jax.nn.hard_tanh,
    jax.nn.leaky_relu,
    jax.nn.log_sigmoid,
    jax.nn.log_softmax,
    mish,
    prelu,
    jax.nn.relu,
    jax.nn.relu6,
    jax.nn.selu,
    jax.nn.sigmoid,
    jax.nn.softplus,
    softshrink,
    softsign,
    squareplus,
    jax.nn.swish,
    jax.nn.tanh,
    tanh_shrink,
    thresholded_relu,
]

# An activation is either a callable over an array-like or one of the
# string names in ``ActivationLiteral``.
ActivationFunctionType = Callable[[jax.typing.ArrayLike], jax.Array]
ActivationType = Union[ActivationLiteral, ActivationFunctionType]

# Name -> callable mapping; ``zip`` stops at the shorter of the two sequences.
act_map = {
    act_name: act_func
    for act_name, act_func in zip(get_args(ActivationLiteral), acts)
}
def resolve_act(act):
    """Resolve an activation specification to a callable.

    Args:
        act: either a callable for which ``inspect.getfullargspec`` reports
            exactly one entry in ``args`` (returned unchanged), or a string
            key of ``act_map``.

    Returns:
        ``act`` itself when it is a callable, otherwise the ``act_map`` entry
        for the given name passed through an identity
        ``jax.tree_util.tree_map``.

    Raises:
        AssertionError: if ``act`` is a callable whose reported argument
            count is not exactly one.
        ValueError: if ``act`` is a string that is not a key of ``act_map``.
        TypeError: if ``act`` is neither a callable nor a string.
    """
    if isinstance(act, ABCCallable):
        # Raise explicitly instead of using a bare ``assert`` so the arity
        # check survives ``python -O``. ``AssertionError`` is kept so callers
        # that already catch it keep working.
        # NOTE(review): ``getfullargspec`` reports ``self`` for bound
        # methods, so those (and possibly callable instances) may be counted
        # with one extra argument - confirm this is intended.
        arg_names = inspect.getfullargspec(act).args
        if len(arg_names) != 1:
            raise AssertionError(
                f"Expected an activation callable with exactly one argument, "
                f"got {len(arg_names)} ({arg_names})."
            )
        return act

    if isinstance(act, str):
        # Only the lookup sits in the ``try`` so that an unrelated
        # ``KeyError`` is never reported as an unknown activation name.
        try:
            entry = act_map[act]
        except KeyError:
            raise ValueError(
                f"Unknown {act=}, available activations: {list(act_map)}"
            ) from None
        return jax.tree_util.tree_map(lambda x: x, entry)

    raise TypeError(f"Unknown activation type {type(act)}.")