build method

```dart
@override
void build(Tensor input)
```

Initializes the layer's parameters based on the shape of the first input.

Subclasses should override this method to create their weights and biases. This method is called automatically by `call` and should not be called directly.

Implementation

/// Allocates the trainable parameters of both tiers, sized from [input].
///
/// The feature count is the length of the first row of `input.value`
/// (0 when that matrix has no rows). Every weight matrix has `hiddenSize`
/// rows; lower-tier matrices have `2 * hiddenSize + featureCount` columns
/// and higher-tier matrices have `2 * hiddenSize` columns. Weight entries
/// are drawn uniformly from `[-sqrt(1 / fanIn), sqrt(1 / fanIn))`, and every
/// bias vector starts as `hiddenSize` zeros. The method finishes by
/// delegating to `super.build`.
@override
void build(Tensor<dynamic> input) {
  final featureMatrix = input.value as Matrix;
  int featureCount = 0;
  if (featureMatrix.isNotEmpty) {
    featureCount = featureMatrix[0].length;
  }

  // A single generator feeds every weight matrix, in assignment order.
  final rng = Random();

  // Lower tier consumes the concatenation [h_lower, h_cell_higher, x_t].
  final int lowerFanIn = 2 * hiddenSize + featureCount;

  // Higher tier consumes the concatenation [h_higher, h_lower].
  final int higherFanIn = 2 * hiddenSize;

  // Builds a fanOut x fanIn matrix, filled row by row with uniform samples
  // scaled by sqrt(1 / fanIn).
  Tensor<Matrix> randomWeights(int fanIn, int fanOut) {
    final double bound = sqrt(1.0 / fanIn);
    final Matrix rows = [
      for (int r = 0; r < fanOut; r++)
        [
          for (int c = 0; c < fanIn; c++) (rng.nextDouble() * 2 - 1) * bound,
        ],
    ];
    return Tensor<Matrix>(rows);
  }

  // Each call yields a fresh fixed-length vector of hiddenSize zeros.
  Tensor<Vector> zeroBias() =>
      Tensor<Vector>(List<double>.filled(hiddenSize, 0.0));

  // Lower tier parameters.
  // NOTE(review): the f/i/c/o suffixes presumably denote the forget, input,
  // cell and output gates — confirm against the field declarations.
  lW_f = randomWeights(lowerFanIn, hiddenSize);
  lW_i = randomWeights(lowerFanIn, hiddenSize);
  lW_c = randomWeights(lowerFanIn, hiddenSize);
  lW_o = randomWeights(lowerFanIn, hiddenSize);
  lb_f = zeroBias();
  lb_i = zeroBias();
  lb_c = zeroBias();
  lb_o = zeroBias();

  // Higher tier parameters.
  hW_f = randomWeights(higherFanIn, hiddenSize);
  hW_i = randomWeights(higherFanIn, hiddenSize);
  hW_c = randomWeights(higherFanIn, hiddenSize);
  hW_o = randomWeights(higherFanIn, hiddenSize);
  hb_f = zeroBias();
  hb_i = zeroBias();
  hb_c = zeroBias();
  hb_o = zeroBias();

  super.build(input);
}