batchNorm1dMath function

Tensor<Vector> batchNorm1dMath(
  Tensor<Vector> x,
  Tensor<Vector> gamma,
  Tensor<Vector> beta,
  Vector runningMean,
  Vector runningVariance,
  int numFeatures,
  bool isTraining,
  double momentum,
  double epsilon,
)

Implementation

/// Applies 1-D batch normalization to a single feature vector and records an
/// autograd node (`batch_norm_1d`) for the backward pass.
///
/// Forward, for every feature `i` in `[0, numFeatures)`:
///
///     xHat[i] = (x[i] - runningMean[i]) / sqrt(runningVariance[i] + epsilon)
///     out[i]  = gamma[i] * xHat[i] + beta[i]
///
/// * When [isTraining] is true, [runningMean] and [runningVariance] are first
///   updated **in place** with
///   `running = momentum * running + (1 - momentum) * current`.
///   The "current" statistics of the single input sample are `x[i]` (mean)
///   and `0.0` (variance). The freshly updated running statistics are then
///   used for the normalization above.
/// * When [isTraining] is false, the running statistics are read but not
///   modified.
///
/// NOTE(review): textbook batch norm normalizes with the *batch* statistics
/// while training. This implementation instead normalizes with the updated
/// running statistics, presumably because the statistics of a single sample
/// have zero variance. Confirm this is intended.
///
/// Backward: mean and variance are treated as constants, so
///
///     gamma.grad[i] += out.grad[i] * xHat[i]
///     beta.grad[i]  += out.grad[i]
///     x.grad[i]     += out.grad[i] * gamma.data[i] * invStd[i]
///
/// where `invStd[i] = 1 / sqrt(runningVariance[i] + epsilon)` is captured
/// during the forward pass.
///
/// Assumes `x.data`, `gamma.data`, `beta.data`, the running-stat vectors and
/// all `.grad` vectors hold at least [numFeatures] entries. This is not
/// checked here.
///
/// Returns a new tensor holding [numFeatures] normalized values.
Tensor<Vector> batchNorm1dMath(
    Tensor<Vector> x,
    Tensor<Vector> gamma,
    Tensor<Vector> beta,
    Vector runningMean,
    Vector runningVariance,
    int numFeatures,
    bool isTraining,
    double momentum,
    double epsilon,
    ) {
  if (isTraining) {
    for (int i = 0; i < numFeatures; i = i + 1) {
      // A single sample's own statistics: mean = x[i], variance = 0.0.
      runningMean[i] = momentum * runningMean[i] + (1.0 - momentum) * x.data[i];
      runningVariance[i] = momentum * runningVariance[i] + (1.0 - momentum) * 0.0;
    }
  }

  // runningMean/runningVariance belong to the caller and are mutated in place
  // by every training call. Snapshot everything the backward pass needs now,
  // so a later forward call cannot change this node's gradient.
  Vector xHat = [];
  Vector invStd = [];
  Vector outValue = [];
  for (int i = 0; i < numFeatures; i = i + 1) {
    final stdDev = sqrt(runningVariance[i] + epsilon);
    final normalized = (x.data[i] - runningMean[i]) / stdDev;
    xHat.add(normalized);
    invStd.add(1.0 / stdDev);
    outValue.add(gamma.data[i] * normalized + beta.data[i]);
  }

  Tensor<Vector> out = Tensor<Vector>(outValue);
  out.creator = Node(
    [x, gamma, beta],
    () {
      for (int i = 0; i < numFeatures; i = i + 1) {
        gamma.grad[i] = gamma.grad[i] + out.grad[i] * xHat[i];
        beta.grad[i] = beta.grad[i] + out.grad[i];
        // Mean/variance are treated as constants, so dxHat/dx = invStd.
        x.grad[i] = x.grad[i] + out.grad[i] * gamma.data[i] * invStd[i];
      }
    },
    opName: 'batch_norm_1d',
    cost: numFeatures * 4,
  );

  return out;
}