RNN class

A simple Recurrent Neural Network (RNN) layer.

An RNN is a basic building block for processing sequential data. It maintains a hidden state (or "memory") that is updated at each timestep by combining the current input with the hidden state from the previous step. This recurrent connection lets the layer learn patterns that span multiple timesteps.

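The core operation at each timestep $t$ is defined by the formula $h_t = \text{activation}(W_{xh}x_t + W_{hh}h_{t-1} + b_h)$, where $x_t$ is the input row for timestep $t$ and $h_{t-1}$ is the hidden state from the previous step.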

  • Input: A Tensor<Matrix> representing the sequence, with a shape of [sequence_length, input_size].
  • Output: A Tensor<Vector> representing the final hidden state after processing the entire sequence, with a shape of [hidden_size].
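To make the recurrence concrete, here is a minimal sketch in plain Dart that traces the formula step by step. It assumes nested `List<double>` values in place of `Tensor`, a zero initial hidden state $h_0$, and W_xh stored as [hidden_size, input_size] so that $W_{xh}x_t$ reads as written; `rnnForward` and the `tanh` helper are hypothetical names, not part of this library's API.

```dart
import 'dart:math' as math;

/// Hypothetical helper (not part of this library): applies
/// h_t = activation(W_xh x_t + W_hh h_{t-1} + b_h) to each row of [sequence]
/// and returns the final hidden state. Assumes h_0 is a zero vector.
List<double> rnnForward(
  List<List<double>> sequence, // [sequence_length][input_size]
  List<List<double>> wXh, // W_xh as [hidden_size][input_size] (assumed layout)
  List<List<double>> wHh, // W_hh as [hidden_size][hidden_size]
  List<double> bH, // b_h as [hidden_size]
  double Function(double) activation,
) {
  final hiddenSize = bH.length;
  var h = List<double>.filled(hiddenSize, 0.0); // assumed h_0 = 0
  for (final x in sequence) {
    final next = List<double>.filled(hiddenSize, 0.0);
    for (var i = 0; i < hiddenSize; i++) {
      var sum = bH[i]; // start from the bias
      for (var j = 0; j < x.length; j++) {
        sum += wXh[i][j] * x[j]; // input-to-hidden term
      }
      for (var k = 0; k < hiddenSize; k++) {
        sum += wHh[i][k] * h[k]; // recurrent term uses the previous state
      }
      next[i] = activation(sum);
    }
    h = next; // the new state feeds the next timestep
  }
  return h;
}

/// tanh written with dart:math, which has no hyperbolic functions.
/// The form 1 - 2 / (e^(2z) + 1) stays finite for large |z|.
double tanh(double z) => 1 - 2 / (math.exp(2 * z) + 1);
```

As a hand-checkable case: with one hidden unit, an identity activation, W_xh = [[1, 1]], W_hh = [[0.5]], and b_h = [0], the sequence [[1, 0], [0, 1]] gives h_1 = 1 and h_2 = 1 + 0.5 × 1 = 1.5, so `rnnForward` returns [1.5]. Only the final state is returned, matching the Output shape [hidden_size].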

Example

// An RNN layer with 16 memory units and a Tanh activation.
Layer rnn = RNN(16, activation: Tanh());

// An input sequence with 3 timesteps and 5 features each.
Tensor<Matrix> sequence = Tensor<Matrix>([
  [0.1, 0.2, 0.3, 0.4, 0.5],
  [0.6, 0.7, 0.8, 0.9, 1.0],
  [0.5, 0.4, 0.3, 0.2, 0.1],
]);

// The output is the final hidden state vector of length 16.
Tensor<Vector> finalState = rnn.call(sequence) as Tensor<Vector>;

Inheritance

Constructors

RNN(int hiddenSize, {required ActivationFunction activation})

Properties

activation ActivationFunction
The non-linear activation function to apply to the hidden state. Tanh is the traditional choice for simple RNNs.
getter/setter pair
b_h Tensor<Vector>
The hidden-state bias vector (b_h in the formula), with shape [hidden_size].
getter/setter pair
hashCode int
The hash code for this object.
no setter, inherited
hiddenSize int
The number of units in the hidden state, representing the "memory" capacity.
getter/setter pair
name String
A user-friendly name for the layer (e.g., 'dense', 'lstm').
getter/setter pair, override-getter
parameters List<Tensor>
Provides the RNN's three trainable parameters (W_xh, W_hh, and b_h) to the optimizer.
no setter, override
runtimeType Type
A representation of the runtime type of the object.
no setter, inherited
W_hh Tensor<Matrix>
The hidden-to-hidden (recurrent) weight matrix.
getter/setter pair
W_xh Tensor<Matrix>
The input-to-hidden weight matrix.
getter/setter pair
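Because the formula fixes how W_xh, W_hh, and b_h combine, their sizes follow from the input width and hiddenSize. The plain-Dart sketch below computes the shapes and the total number of trainable values. `rnnParameterShapes` and `rnnParameterCount` are hypothetical helpers, and the [hidden_size, input_size] orientation of W_xh is an assumption read off the term $W_{xh}x_t$ (the library may store the transpose); the total count is the same for either orientation.

```dart
/// Hypothetical helper (not part of this library): shapes of W_xh, W_hh,
/// and b_h implied by h_t = activation(W_xh x_t + W_hh h_{t-1} + b_h).
/// Assumes W_xh is laid out as [hidden_size, input_size].
Map<String, List<int>> rnnParameterShapes(int inputSize, int hiddenSize) => {
      'W_xh': [hiddenSize, inputSize],
      'W_hh': [hiddenSize, hiddenSize],
      'b_h': [hiddenSize],
    };

/// Total trainable scalars: hidden * input + hidden * hidden + hidden.
/// This total does not depend on how W_xh is oriented.
int rnnParameterCount(int inputSize, int hiddenSize) {
  var total = 0;
  for (final shape in rnnParameterShapes(inputSize, hiddenSize).values) {
    // Multiply the dimensions of each parameter to get its element count.
    total += shape.fold<int>(1, (a, b) => a * b);
  }
  return total;
}
```

For the example layer (hiddenSize 16, five input features per timestep), this gives 16 × 5 + 16 × 16 + 16 = 80 + 256 + 16 = 352 trainable values. Note that W_hh grows quadratically with hiddenSize, so the recurrent matrix dominates the count for wide layers.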

Methods

build(Tensor input) → void
Initializes the W_xh, W_hh, and b_h parameter tensors.
override
call(Tensor input) → Tensor
The public, callable interface for the layer.
inherited
forward(Tensor input) → Tensor<Vector>
Performs the forward pass for the RNN layer, applying the recurrence at every timestep and returning the final hidden state.
override
noSuchMethod(Invocation invocation) → dynamic
Invoked when a nonexistent method or property is accessed.
inherited
toString() → String
A string representation of this object.
inherited
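The documentation does not spell out how `build`, `call`, and `forward` interact. A common pattern, assumed here rather than confirmed by this library, is that `call` builds the layer on first use by reading the input width from the first input, then delegates to `forward`. The sketch below mirrors the documented members in plain Dart with nested lists instead of `Tensor`; the class `SketchRnn`, the `isBuilt` getter, the `seed` parameter, and the uniform [-0.1, 0.1] weight initialization are all hypothetical illustrations.

```dart
import 'dart:math' as math;

/// Hypothetical stand-in for the documented RNN layer, using nested lists
/// instead of Tensor. The lazy build in [call] and the initialization
/// scheme in [build] are assumptions, not the library's documented behavior.
class SketchRnn {
  SketchRnn(this.hiddenSize, {required this.activation, int seed = 0})
      : _random = math.Random(seed);

  final int hiddenSize;
  final double Function(double) activation;
  final math.Random _random;

  List<List<double>>? wXh; // W_xh as [hidden_size][input_size] (assumed)
  List<List<double>>? wHh; // W_hh as [hidden_size][hidden_size]
  List<double>? bH; // b_h as [hidden_size]

  bool get isBuilt => wXh != null;

  /// Mirrors `parameters`: the three trainable values, once built.
  List<Object> get parameters => [wXh!, wHh!, bH!];

  /// Mirrors `build`: infers input_size from the first input's row length.
  /// Uniform values in [-0.1, 0.1] and a zero bias are illustrative only.
  void build(List<List<double>> input) {
    final inputSize = input.first.length;
    double small() => (_random.nextDouble() * 2 - 1) * 0.1;
    wXh = List.generate(
        hiddenSize, (_) => List.generate(inputSize, (_) => small()));
    wHh = List.generate(
        hiddenSize, (_) => List.generate(hiddenSize, (_) => small()));
    bH = List.filled(hiddenSize, 0.0);
  }

  /// Mirrors `call`: builds on first use (assumed), then runs forward.
  List<double> call(List<List<double>> input) {
    if (!isBuilt) {
      build(input);
    }
    return forward(input);
  }

  /// Mirrors `forward`: applies the recurrence and returns the final state.
  List<double> forward(List<List<double>> input) {
    final w = wXh!, u = wHh!, b = bH!;
    var h = List<double>.filled(hiddenSize, 0.0); // assumed h_0 = 0
    for (final x in input) {
      // The new list is built entirely from the previous h before reassignment.
      h = List.generate(hiddenSize, (i) {
        var sum = b[i];
        for (var j = 0; j < x.length; j++) {
          sum += w[i][j] * x[j];
        }
        for (var k = 0; k < hiddenSize; k++) {
          sum += u[i][k] * h[k];
        }
        return activation(sum);
      });
    }
    return h;
  }
}
```

Deferring weight creation until the first input arrives means the constructor only needs hiddenSize and the activation, which matches the documented constructor signature; the input width is discovered from the data rather than declared up front.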

Operators

operator ==(Object other) → bool
The equality operator.
inherited