build method

@override
void build(Tensor input)
override

Initializes the 8 kernels (an input-side and a hidden-side kernel for each of the 4 LSTM gates) and the 4 zero-filled gate biases, then calls `super.build`.

Implementation

/// Initializes the eight gate kernels and four gate biases, then calls
/// `super.build`.
///
/// Each gate gets one kernel applied to the input (`K_x*`) and one applied
/// to the hidden state (`K_h*`). The `f`/`i`/`c`/`o` suffixes presumably
/// denote the forget, input, cell-candidate and output gates — confirm
/// against the forward pass.
///
/// Kernels are `kernelSize x kernelSize` matrices whose entries are drawn
/// uniformly from `[-1/kernelSize, 1/kernelSize)` using an unseeded
/// [Random]. Biases are zero matrices with the same height and width as the
/// first frame of [input].
///
/// Throws an [ArgumentError] if [input] holds no frames or its first frame
/// has no rows, because the bias shape cannot be inferred in that case.
@override
void build(Tensor<dynamic> input) {
  final inputSequence = input.value as Tensor3D;

  // The bias shape is taken from the first frame, so it must exist and have
  // at least one row; otherwise indexing below would fail with a RangeError.
  if (inputSequence.isEmpty) {
    throw ArgumentError.value(
        input, 'input', 'Cannot build from an empty input sequence');
  }
  final firstFrame = inputSequence[0];
  if (firstFrame.isEmpty) {
    throw ArgumentError.value(
        input, 'input', 'Cannot build from an input frame with no rows');
  }

  final height = firstFrame.length;
  final width = firstFrame[0].length;
  final random = Random();

  // Returns a size x size kernel with entries uniform in [-limit, limit).
  // For a positive size, sqrt(1 / size^2) == 1 / size; `limit` is the
  // half-width of the sampling range, not a standard deviation.
  Tensor<Matrix> initKernel(int size) {
    final limit = sqrt(1.0 / (size * size));
    // List.generate fills indices in increasing order, so random values are
    // drawn row by row, left to right.
    final Matrix values = List.generate(
      size,
      (_) => List.generate(
        size,
        (_) => (random.nextDouble() * 2 - 1) * limit,
      ),
    );
    return Tensor<Matrix>(values);
  }

  // Since 'same' padding is used, the bias shape matches the input frame shape.
  Tensor<Matrix> initBias() {
    final Matrix biasValues = List.generate(
      height,
      (_) => List<double>.filled(width, 0.0),
    );
    return Tensor<Matrix>(biasValues);
  }

  // Initialization order is kept fixed so random draws are consumed in the
  // same sequence for every build.
  K_xf = initKernel(kernelSize);
  K_hf = initKernel(kernelSize);
  K_xi = initKernel(kernelSize);
  K_hi = initKernel(kernelSize);
  K_xc = initKernel(kernelSize);
  K_hc = initKernel(kernelSize);
  K_xo = initKernel(kernelSize);
  K_ho = initKernel(kernelSize);

  b_f = initBias();
  b_i = initBias();
  b_c = initBias();
  b_o = initBias();

  super.build(input);
}