# Copyright 2024 serket authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines RNN related classes."""
from __future__ import annotations
import abc
import functools as ft
from typing import Any, Callable
import jax
import jax.numpy as jnp
import jax.random as jr
from serket import TreeClass, autoinit
from serket._src.custom_transform import tree_state
from serket._src.nn.activation import ActivationType, resolve_act
from serket._src.nn.convolution import (
Conv1D,
Conv2D,
Conv3D,
FFTConv1D,
FFTConv2D,
FFTConv3D,
)
from serket._src.nn.linear import Linear
from serket._src.utils.lazy import maybe_lazy_call, maybe_lazy_init
from serket._src.utils.typing import (
DilationType,
DType,
InitType,
KernelSizeType,
PaddingType,
S,
StridesType,
)
from serket._src.utils.validate import (
validate_in_features_shape,
validate_pos_int,
validate_spatial_ndim,
)
def is_lazy_call(instance, *_1, **_2) -> bool:
return instance.in_features is None
def is_lazy_init(_, in_features: int | None, *_1, **_2) -> bool:
return in_features is None
def infer_in_features(_, input: jax.Array, *_1, **_2) -> int:
return input.shape[0]
updates = dict(in_features=infer_in_features)
@autoinit
class RNNState(TreeClass):
hidden_state: jax.Array
class SimpleRNNState(RNNState): ...
[docs]
class SimpleRNNCell(TreeClass):
"""Vanilla RNN cell that defines the update rule for the hidden state
Args:
in_features: the number of input features
hidden_features: the number of hidden features
key: the key to use to initialize the weights
weight_init: the function to use to initialize the weights
bias_init: the function to use to initialize the bias
recurrent_weight_init: the function to use to initialize the recurrent weights
act: the activation function to use for the hidden state update
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.SimpleRNNCell(10, 20, key=jr.key(0))
>>> # 20-dimensional hidden state
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(cell)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape # 20 features
(20,)
Note:
:class:`.SimpleRNNCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.SimpleRNNCell(None, 20, key=jr.key(0))
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(lazy)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(20,)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNNCell.
"""
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
*,
key: jax.Array,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
recurrent_weight_init: InitType = "orthogonal",
act: ActivationType = jax.nn.tanh,
dtype: DType = jnp.float32,
):
k1, k2 = jr.split(key, 2)
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
i2h = Linear(
in_features,
hidden_features,
weight_init=weight_init,
bias_init=bias_init,
key=k1,
dtype=dtype,
)
h2h = Linear(
hidden_features,
hidden_features,
weight_init=recurrent_weight_init,
bias_init=None,
key=k2,
dtype=dtype,
)
self.in_hidden_to_hidden = Linear(
in_features=in_features + hidden_features,
out_features=hidden_features,
weight_init=lambda *_: jnp.concatenate([i2h.weight, h2h.weight], axis=-1),
bias_init=lambda *_: i2h.bias,
dtype=dtype,
key=k1, # dummy key
out_axis=0,
)
[docs]
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: SimpleRNNState,
) -> tuple[jax.Array, SimpleRNNState]:
if not isinstance(state, SimpleRNNState):
raise TypeError(f"Expected {state=} to be an instance of `SimpleRNNState`")
ih = jnp.concatenate([input, state.hidden_state], axis=-1)
h = self.in_hidden_to_hidden(ih)
h = self.act(h)
return h, SimpleRNNState(h)
spatial_ndim: int = 0
class LinearState(RNNState): ...
[docs]
class LinearCell(TreeClass):
"""No hidden state cell that applies a dense(Linear+activation) layer to the input
Args:
in_features: the number of input features
hidden_features: the number of hidden features
key: the key to use to initialize the weights
weight_init: the function to use to initialize the weights
bias_init: the function to use to initialize the bias
act: the activation function to use for the hidden state update,
use `None` for no activation
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.LinearCell(10, 20, key=jr.key(0))
>>> # 20-dimensional hidden state
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(cell)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape # 20 features
(20,)
Note:
:class:`.LinearCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.LinearCell(None, 20, key=jr.key(0))
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(lazy)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(20,)
"""
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
*,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
act: ActivationType = jax.nn.tanh,
key: jax.Array,
dtype: DType = jnp.float32,
):
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
self.in_to_hidden = Linear(
in_features,
hidden_features,
weight_init=weight_init,
bias_init=bias_init,
key=key,
dtype=dtype,
out_axis=0,
)
[docs]
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: LinearState,
) -> tuple[jax.Array, LinearState]:
if not isinstance(state, LinearState):
raise TypeError(f"Expected {state=} to be an instance of `LinearState`")
h = self.in_to_hidden(input)
h = self.act(h)
return h, LinearState(h)
spatial_ndim: int = 0
@autoinit
class LSTMState(RNNState):
cell_state: jax.Array
[docs]
class LSTMCell(TreeClass):
"""LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: the number of input features
hidden_features: the number of hidden features
weight_init: the function to use to initialize the weights
bias_init: the function to use to initialize the bias
recurrent_weight_init: the function to use to initialize the recurrent weights
act: the activation function to use for the hidden state update
recurrent_act: the activation function to use for the cell state update
key: the key to use to initialize the weights
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.LSTMCell(10, 20, key=jr.key(0))
>>> # 20-dimensional hidden state
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(cell)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape # 20 features
(20,)
Note:
:class:`.LSTMCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.LSTMCell(None, 20, key=jr.key(0))
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(lazy)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(20,)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTMCell
- https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/recurrent.py
"""
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
*,
key: jax.Array,
weight_init: str | Callable = "glorot_uniform",
bias_init: str | Callable | None = "zeros",
recurrent_weight_init: str | Callable = "orthogonal",
act: str | Callable[[Any], Any] | None = "tanh",
recurrent_act: ActivationType | None = "sigmoid",
dtype: DType = jnp.float32,
):
k1, k2 = jr.split(key, 2)
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
self.recurrent_act = resolve_act(recurrent_act)
i2h = Linear(
in_features,
hidden_features * 4,
weight_init=weight_init,
bias_init=bias_init,
key=k1,
dtype=dtype,
)
h2h = Linear(
hidden_features,
hidden_features * 4,
weight_init=recurrent_weight_init,
bias_init=None,
key=k2,
dtype=dtype,
)
self.in_hidden_to_hidden = Linear(
in_features=in_features + hidden_features,
out_features=hidden_features * 4,
weight_init=lambda *_: jnp.concatenate([i2h.weight, h2h.weight], axis=-1),
bias_init=lambda *_: i2h.bias,
dtype=dtype,
key=k1, # dummy key
out_axis=0,
)
[docs]
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: LSTMState,
) -> tuple[jax.Array, LSTMState]:
if not isinstance(state, LSTMState):
raise TypeError(f"Expected {state=} to be an instance of `LSTMState`")
h, c = state.hidden_state, state.cell_state
ih = jnp.concatenate([input, h], axis=-1)
h = self.in_hidden_to_hidden(ih)
i, f, g, o = jnp.split(h, 4)
i = self.recurrent_act(i)
f = self.recurrent_act(f)
g = self.act(g)
o = self.recurrent_act(o)
c = f * c + i * g
h = o * self.act(c)
return h, LSTMState(h, c)
spatial_ndim: int = 0
class GRUState(RNNState): ...
[docs]
class GRUCell(TreeClass):
"""GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: the number of input features
hidden_features: the number of hidden features
key: the key to use to initialize the weights
weight_init: the function to use to initialize the weights
bias_init: the function to use to initialize the bias
recurrent_weight_init: the function to use to initialize the recurrent weights
act: the activation function to use for the hidden state update
recurrent_act: the activation function to use for the cell state update
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> # 10-dimensional input, 20-dimensional hidden state
>>> cell = sk.nn.GRUCell(10, 20, key=jr.key(0))
>>> # 20-dimensional hidden state
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(cell)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape # 20 features
(20,)
Note:
:class:`.GRUCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.GRUCell(None, 20, key=jr.key(0))
>>> input = jnp.ones(10) # 10 features
>>> state = sk.tree_state(lazy)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(20,)
Reference:
- https://keras.io/api/layers/recurrent_layers/gru/
"""
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
*,
key: jax.Array,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
recurrent_weight_init: InitType = "orthogonal",
act: ActivationType | None = "tanh",
recurrent_act: ActivationType | None = "sigmoid",
dtype: DType = jnp.float32,
):
k1, k2 = jr.split(key, 2)
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
self.recurrent_act = resolve_act(recurrent_act)
self.in_to_hidden = Linear(
in_features,
hidden_features * 3,
weight_init=weight_init,
bias_init=bias_init,
key=k1,
dtype=dtype,
out_axis=0,
)
self.hidden_to_hidden = Linear(
hidden_features,
hidden_features * 3,
weight_init=recurrent_weight_init,
bias_init=None,
key=k2,
dtype=dtype,
out_axis=0,
)
[docs]
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: GRUState,
) -> tuple[jax.Array, GRUState]:
if not isinstance(state, GRUState):
raise TypeError(f"Expected {state=} to be an instance of `GRUState`")
h = state.hidden_state
xe, xu, xo = jnp.split(self.in_to_hidden(input), 3)
he, hu, ho = jnp.split(self.hidden_to_hidden(h), 3)
e = self.recurrent_act(xe + he)
u = self.recurrent_act(xu + hu)
o = self.act(xo + (e * ho))
h = (1 - u) * o + u * h
return h, GRUState(hidden_state=h)
spatial_ndim: int = 0
@autoinit
class ConvLSTMNDState(RNNState):
cell_state: jax.Array
class ConvLSTMNDCell(TreeClass):
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
kernel_size: KernelSizeType,
*,
key: jax.Array,
strides: StridesType = 1,
padding: PaddingType = "same",
dilation: DilationType = 1,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
recurrent_weight_init: InitType = "orthogonal",
act: ActivationType | None = "tanh",
recurrent_act: ActivationType | None = "hard_sigmoid",
dtype: DType = jnp.float32,
):
k1, k2 = jr.split(key, 2)
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
self.recurrent_act = resolve_act(recurrent_act)
self.in_to_hidden = self.conv_layer(
in_features,
hidden_features * 4,
kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
weight_init=weight_init,
bias_init=bias_init,
key=k1,
dtype=dtype,
)
self.hidden_to_hidden = self.conv_layer(
hidden_features,
hidden_features * 4,
kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
weight_init=recurrent_weight_init,
bias_init=None,
key=k2,
dtype=dtype,
)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: ConvLSTMNDState,
) -> tuple[jax.Array, ConvLSTMNDState]:
if not isinstance(state, ConvLSTMNDState):
raise TypeError(f"Expected {state=} to be an instance of ConvLSTMNDState.")
h, c = state.hidden_state, state.cell_state
h = self.in_to_hidden(input) + self.hidden_to_hidden(h)
i, f, g, o = jnp.split(h, 4, axis=0)
i = self.recurrent_act(i)
f = self.recurrent_act(f)
g = self.act(g)
o = self.recurrent_act(o)
c = f * c + i * g
h = o * self.act(c)
return h, ConvLSTMNDState(h, c)
@property
@abc.abstractmethod
def conv_layer(self): ...
spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class ConvLSTM1DCell(ConvLSTMNDCell):
"""1D Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: PRNG key
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvLSTM1DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4)
Note:
:class:`.ConvLSTM1DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvLSTM1DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(lazy, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4)
Reference:
https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
"""
spatial_ndim: int = 1
conv_layer = Conv1D
[docs]
class FFTConvLSTM1DCell(ConvLSTMNDCell):
"""1D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: PRNG key
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvLSTM1DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4)
Note:
:class:`.FFTConvLSTM1DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvLSTM1DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D
"""
spatial_ndim: int = 1
conv_layer = FFTConv1D
[docs]
class ConvLSTM2DCell(ConvLSTMNDCell):
"""2D Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to use to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvLSTM2DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Note:
:class:`.ConvLSTM2DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvLSTM2DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(lazy, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D
"""
spatial_ndim: int = 2
conv_layer = Conv2D
[docs]
class FFTConvLSTM2DCell(ConvLSTMNDCell):
"""2D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvLSTM2DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Note:
:class:`.FFTConvLSTM2DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvLSTM2DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(lazy, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM2D
"""
spatial_ndim: int = 2
conv_layer = FFTConv2D
[docs]
class ConvLSTM3DCell(ConvLSTMNDCell):
"""3D Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvLSTM3DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Note:
:class:`.ConvLSTM3DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvLSTM3DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM3D
"""
spatial_ndim: int = 3
conv_layer = Conv3D
[docs]
class FFTConvLSTM3DCell(ConvLSTMNDCell):
"""3D FFT Convolution LSTM cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvLSTM3DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Note:
:class:`.FFTConvLSTM3DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvLSTM3DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Reference:
- https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM3D
"""
spatial_ndim: int = 3
conv_layer = FFTConv3D
class ConvGRUNDState(RNNState): ...
class ConvGRUNDCell(TreeClass):
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
hidden_features: int,
kernel_size: int | tuple[int, ...],
*,
key: jax.Array,
strides: StridesType = 1,
padding: PaddingType = "same",
dilation: DilationType = 1,
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
recurrent_weight_init: InitType = "orthogonal",
act: ActivationType | None = "tanh",
recurrent_act: ActivationType | None = "sigmoid",
dtype: DType = jnp.float32,
):
k1, k2 = jr.split(key, 2)
self.in_features = validate_pos_int(in_features)
self.hidden_features = validate_pos_int(hidden_features)
self.act = resolve_act(act)
self.recurrent_act = resolve_act(recurrent_act)
self.in_to_hidden = self.conv_layer(
in_features,
hidden_features * 3,
kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
weight_init=weight_init,
bias_init=bias_init,
key=k1,
dtype=dtype,
)
self.hidden_to_hidden = self.conv_layer(
hidden_features,
hidden_features * 3,
kernel_size,
strides=strides,
padding=padding,
dilation=dilation,
weight_init=recurrent_weight_init,
bias_init=None,
key=k2,
dtype=dtype,
)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, argnum=0)
@ft.partial(validate_in_features_shape, axis=0)
def __call__(
self,
input: jax.Array,
state: ConvGRUNDState,
) -> tuple[jax.Array, ConvGRUNDState]:
if not isinstance(state, ConvGRUNDState):
raise TypeError(f"Expected {state=} to be an instance of `GRUState`")
h = state.hidden_state
xe, xu, xo = jnp.split(self.in_to_hidden(input), 3)
he, hu, ho = jnp.split(self.hidden_to_hidden(h), 3)
e = self.recurrent_act(xe + he)
u = self.recurrent_act(xu + hu)
o = self.act(xo + (e * ho))
h = (1 - u) * o + u * h
return h, ConvGRUNDState(h)
@property
@abc.abstractmethod
def conv_layer(self): ...
spatial_ndim = property(abc.abstractmethod(lambda _: ...))
[docs]
class ConvGRU1DCell(ConvGRUNDCell):
"""1D Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvGRU1DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4)
Note:
:class:`.ConvGRU1DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvGRU1DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4)
"""
spatial_ndim: int = 1
conv_layer = Conv1D
[docs]
class FFTConvGRU1DCell(ConvGRUNDCell):
"""1D FFT Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvGRU1DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4)
Note:
:class:`.FFTConvGRU1DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvGRU1DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4)
"""
spatial_ndim: int = 1
conv_layer = FFTConv1D
[docs]
class ConvGRU2DCell(ConvGRUNDCell):
"""2D Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvGRU2DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Note:
:class:`.ConvGRU2DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvGRU2DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
"""
spatial_ndim: int = 2
conv_layer = Conv2D
[docs]
class FFTConvGRU2DCell(ConvGRUNDCell):
"""2D FFT Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvGRU2DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
Note:
:class:`.FFTConvGRU2DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvGRU2DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4)
"""
spatial_ndim: int = 2
conv_layer = FFTConv2D
[docs]
class ConvGRU3DCell(ConvGRUNDCell):
"""3D Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.ConvGRU3DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Note:
:class:`.ConvGRU3DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.ConvGRU3DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(lazy, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
"""
spatial_ndim: int = 3
conv_layer = Conv3D
[docs]
class FFTConvGRU3DCell(ConvGRUNDCell):
"""3D Convolution GRU cell that defines the update rule for the hidden state and cell state
Args:
in_features: Number of input features
hidden_features: Number of output features
key: random key to initialize weights.
kernel_size: Size of the convolutional kernel
strides: Stride of the convolution
padding: Padding of the convolution
dilation: Dilation of the convolutional kernel
weight_init: Weight initialization function
bias_init: Bias initialization function
recurrent_weight_init: Recurrent weight initialization function
act: Activation function
recurrent_act: Recurrent activation function
dtype: dtype of the weights and biases. ``float32``
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> cell = sk.nn.FFTConvGRU3DCell(10, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> output, state = cell(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
Note:
:class:`.FFTConvGRU3DCell` supports lazy initialization, meaning that the
weights and biases are not initialized until the first call to the layer.
This is useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use :func:`.value_and_tree` to call the layer and return the method
output and the material layer.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> lazy = sk.nn.FFTConvGRU3DCell(None, 2, 3, key=jr.key(0))
>>> input = jnp.ones((10, 4, 4, 4)) # time, in_features, spatial dimensions
>>> state = sk.tree_state(cell, input=input)
>>> _, material = sk.value_and_tree(lambda cell: cell(input, state))(cell)
>>> output, state = material(input, state)
>>> state.hidden_state.shape
(2, 4, 4, 4)
"""
spatial_ndim: int = 3
conv_layer = FFTConv3D
[docs]
def scan_cell(
cell,
in_axis: int = 0,
out_axis: int = 0,
reverse: bool = False,
) -> Callable[[jax.Array, S], tuple[jax.Array, S]]:
"""Scan am RNN cell over a sequence.
Args:
cell: the RNN cell to scan. The cell should have the following signature:
`cell(input, state) -> tuple[output, state]`
in_axis: the axis to scan over. Defaults to 0.
out_axis: the axis to move the output to. Defaults to 0.
reverse: whether to scan the sequence in reverse order. Defaults to ``False``.
Example:
Unidirectional RNN:
>>> import serket as sk
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> key = jr.key(0)
>>> cell = sk.nn.SimpleRNNCell(1, 2, key=key)
>>> state = sk.tree_state(cell)
>>> input = jnp.ones([10, 1])
>>> output, state = sk.nn.scan_cell(cell)(input, state)
>>> print(output.shape)
(10, 2)
Example:
Bidirectional RNN:
>>> import serket as sk
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> k1, k2 = jr.split(jr.key(0))
>>> cell1 = sk.nn.SimpleRNNCell(1, 2, key=k1)
>>> cell2 = sk.nn.SimpleRNNCell(1, 2, key=k2)
>>> state1, state2 = sk.tree_state((cell1, cell2))
>>> input = jnp.ones([10, 1])
>>> output1, state1 = sk.nn.scan_cell(cell1)(input, state1)
>>> output2, state2 = sk.nn.scan_cell(cell2, reverse=True)(input, state2)
>>> output = jnp.concatenate((output1, output2), axis=1)
>>> print(output.shape)
(10, 4)
Example:
Combining multiple RNN cells:
>>> import serket as sk
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import numpy.testing as npt
>>> k1, k2 = jr.split(jr.key(0))
>>> cell1 = sk.nn.LSTMCell(1, 2, bias_init=None, key=k1)
>>> cell2 = sk.nn.LSTMCell(2, 1, bias_init=None, key=k2)
>>> def cell(input, state):
... state1, state2 = state
... output, state1 = cell1(input, state1)
... output, state2 = cell2(output, state2)
... return output, (state1, state2)
>>> state = sk.tree_state((cell1, cell2))
>>> input = jnp.ones([2, 1])
>>> output1, state = sk.nn.scan_cell(cell)(input, state)
<BLANKLINE>
>>> # This is equivalent to:
>>> state1, state2 = sk.tree_state((cell1, cell2))
>>> output2 = jnp.zeros([2, 1])
>>> # first step
>>> output, state1 = cell1(input[0], state1)
>>> output, state2 = cell2(output, state2)
>>> output2 = output2.at[0].set(output)
>>> # second step
>>> output, state1 = cell1(input[1], state1)
>>> output, state2 = cell2(output, state2)
>>> output2 = output2.at[1].set(output)
>>> npt.assert_allclose(output1, output2, atol=1e-6)
"""
def scan_func(state: S, input: jax.Array) -> tuple[S, jax.Array]:
output, state = cell(input, state)
return state, output
def wrapper(input: jax.Array, state: S) -> tuple[jax.Array, S]:
# push the scan axis to the front
input = jnp.moveaxis(input, in_axis, 0)
state, output = jax.lax.scan(scan_func, state, input, reverse=reverse)
# move the output axis to the desired location
output = jnp.moveaxis(output, 0, out_axis)
return output, state
return wrapper
# register state handlers
@tree_state.def_state(SimpleRNNCell)
def _(cell: SimpleRNNCell) -> SimpleRNNState:
return SimpleRNNState(jnp.zeros([cell.hidden_features]))
@tree_state.def_state(LinearCell)
def _(cell: LinearCell) -> LinearState:
return LinearState(jnp.empty([cell.hidden_features]))
@tree_state.def_state(LSTMCell)
def _(cell: LSTMCell) -> LSTMState:
shape = [cell.hidden_features]
return LSTMState(jnp.zeros(shape), jnp.zeros(shape))
@tree_state.def_state(GRUCell)
def _(cell: GRUCell) -> GRUState:
return GRUState(jnp.zeros([cell.hidden_features]))
def _check_rnn_cell_tree_state_input(cell, input):
if not (hasattr(input, "ndim") and hasattr(input, "shape")):
raise TypeError(
f"Expected {input=} to have `ndim` and `shape` attributes."
f"To initialize the `{type(cell).__name__}` state.\n"
"Pass a single sample input to `tree_state(..., input=)`."
)
if input.ndim != cell.spatial_ndim + 1:
raise ValueError(
f"{input.ndim=} != {(cell.spatial_ndim+1)=}.\n"
f"Expected input to {type(cell).__name__} to have `shape` (in_features, {'... '*cell.spatial_ndim}).\n"
f"Pass a single sample input to `tree_state({type(cell).__name__}, input=...)`"
)
if len(spatial_dim := input.shape[1:]) != cell.spatial_ndim:
raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.")
return input
@tree_state.def_state(ConvLSTMNDCell)
def _(cell: ConvLSTMNDCell, input) -> ConvLSTMNDState:
input = _check_rnn_cell_tree_state_input(cell, input)
shape = (cell.hidden_features, *input.shape[1:])
zeros = jnp.zeros(shape).astype(input.dtype)
return ConvLSTMNDState(zeros, zeros)
@tree_state.def_state(ConvGRUNDCell)
def _(cell: ConvGRUNDCell, *, input: Any) -> ConvGRUNDState:
input = _check_rnn_cell_tree_state_input(cell, input)
shape = (cell.hidden_features, *input.shape[1:])
return ConvGRUNDState(jnp.zeros(shape).astype(input.dtype))