layerNorm function

Tensor<Matrix> layerNorm(
  Tensor<Matrix> m,
  Tensor<Vector> gamma,
  Tensor<Vector> beta, {
  double epsilon = 1e-5,
})

Implementation

/// Applies layer normalization to each row of [m], followed by a per-column
/// affine transform:
///
///   out[r][c] = gamma[c] * (m[r][c] - mean_r) / sqrt(var_r + epsilon) + beta[c]
///
/// where mean_r / var_r are the (biased, population) mean and variance of
/// row r. [gamma] and [beta] are expected to have one entry per column —
/// TODO confirm against callers; this is not checked here.
///
/// The returned tensor carries a backward [Node] that accumulates gradients
/// into `m.grad`, `gamma.grad`, and `beta.grad` using the standard
/// layer-norm backward formula.
///
/// [epsilon] keeps the divisor positive for constant (zero-variance) rows.
Tensor<Matrix> layerNorm(Tensor<Matrix> m, Tensor<Vector> gamma, Tensor<Vector> beta, {double epsilon = 1e-5}) {
  int numRows = m.value.length;
  // Guard the empty-matrix case so m.value[0] is never read when there
  // are no rows (the loops below then simply do nothing).
  int numCols = numRows > 0 ? m.value[0].length : 0;
  Matrix outValue = [];

  // Cached per-row values reused by the backward pass.
  // (A `means` buffer was previously kept here too, but the backward pass
  // never reads the means — only x̂ and the variances — so it was dropped.)
  Matrix normalizedRows = [];
  Vector variances = [];

  for (int r = 0; r < numRows; r++) {
    Vector row = m.value[r];

    double mean = 0;
    for (double val in row) { mean += val; }
    mean /= numCols;

    // Biased (divide-by-N) variance, as is conventional for layer norm.
    double variance = 0;
    for (double val in row) {
      double diff = val - mean;
      variance += diff * diff; // cheaper and exact vs. pow(diff, 2)
    }
    variance /= numCols;
    variances.add(variance);

    // Hoist the reciprocal std-dev out of the per-element loop.
    double invStd = 1.0 / sqrt(variance + epsilon);

    // Fused: normalize and apply the gamma/beta affine in one pass.
    Vector normalizedRow = [];
    Vector finalRow = [];
    for (int c = 0; c < numCols; c++) {
      double xHat = (row[c] - mean) * invStd;
      normalizedRow.add(xHat);
      finalRow.add(gamma.value[c] * xHat + beta.value[c]);
    }
    normalizedRows.add(normalizedRow);
    outValue.add(finalRow);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outValue);
  // Rough per-element op-count estimate used for cost bookkeeping.
  int cost = numRows * numCols * 8;

  out.creator = Node([m, gamma, beta], () {
    for (int r = 0; r < numRows; r++) {
      // dL/dx̂ for this row, plus dense accumulation into gamma/beta grads.
      Vector grad_x_hat = [];
      for (int c = 0; c < numCols; c++) {
        grad_x_hat.add(out.grad[r][c] * gamma.value[c]);
        gamma.grad[c] += out.grad[r][c] * normalizedRows[r][c];
        beta.grad[c] += out.grad[r][c];
      }

      // Σ_c dL/dx̂_c — contribution through the row mean.
      double sum_grad_x_hat = 0;
      for (double val in grad_x_hat) { sum_grad_x_hat += val; }

      // Σ_c dL/dx̂_c · x̂_c — contribution through the row variance.
      double dot_product_term = 0;
      for (int c = 0; c < numCols; c++) {
        dot_product_term += grad_x_hat[c] * normalizedRows[r][c];
      }

      // Standard layer-norm backward:
      //   dx = (1 / (N·sqrt(var+eps))) · (N·dx̂ - Σdx̂ - x̂·Σ(dx̂·x̂))
      for (int c = 0; c < numCols; c++) {
        double total_grad = (1.0 / (numCols * sqrt(variances[r] + epsilon))) *
            (numCols * grad_x_hat[c] -
                sum_grad_x_hat -
                normalizedRows[r][c] * dot_product_term);
        m.grad[r][c] += total_grad;
      }
    }
  }, opName: 'layer_norm', cost: cost);
  return out;
}