batchNorm2dMath function

Tensor<Tensor3D> batchNorm2dMath(
  1. Tensor<Tensor3D> x,
  2. Tensor<Vector> gamma,
  3. Tensor<Vector> beta,
  4. Vector runningMean,
  5. Vector runningVariance,
  6. int numChannels,
  7. bool isTraining,
  8. double momentum,
  9. double epsilon,
)

Implementation

/// Per-channel batch normalization over a 3D tensor, with autograd support.
///
/// For every channel `c` in `0 .. numChannels - 1` the function normalizes the
/// `height * width` values of that channel's plane and applies the affine
/// transform `gamma[c] * xHat + beta[c]`.
///
/// Data layout: `x.data`, `x.grad` and `out.grad` are indexed as flat lists
/// with `index = c * (height * width) + h * width + w`, where `height` is
/// `x.shape[1]` and `width` is `x.shape[2]`.
/// NOTE(review): `numChannels` is presumably equal to `x.shape[0]`; this is
/// not checked here - confirm against callers.
///
/// Parameters:
/// - [x]: input tensor; read through `x.shape` and the flat `x.data`.
/// - [gamma], [beta]: per-channel scale and shift, read through `.data[c]`.
/// - [runningMean], [runningVariance]: per-channel running statistics. They
///   are **mutated in place** when [isTraining] is true and only read
///   otherwise.
/// - [numChannels]: number of channels to process.
/// - [isTraining]: when true, normalization uses the statistics of the current
///   input (population variance, i.e. divided by the element count) and the
///   running statistics are updated as
///   `running = momentum * running + (1 - momentum) * current`.
///   When false, the running statistics are used as-is.
/// - [momentum]: weight given to the *previous* running value in the update.
/// - [epsilon]: added to the variance before the square root.
///
/// Returns a new tensor built from nested lists (channel -> row -> value)
/// whose `creator` is a `Node` over `[x, gamma, beta]`. The node's callback
/// accumulates gradients into `gamma.grad`, `beta.grad` and `x.grad` from
/// `out.grad`.
Tensor<Tensor3D> batchNorm2dMath(
  Tensor<Tensor3D> x,
  Tensor<Vector> gamma,
  Tensor<Vector> beta,
  Vector runningMean,
  Vector runningVariance,
  int numChannels,
  bool isTraining,
  double momentum,
  double epsilon,
) {
  int height = x.shape[1];
  int width = x.shape[2];
  int planeSize = height * width;
  double numElements = planeSize.toDouble();

  // Statistics actually used by the forward pass, one entry per channel.
  // The inverse standard deviation is stored as a value (not as a reference
  // to runningVariance) so the backward pass always sees exactly what the
  // forward pass used, even if the running statistics are modified later.
  List<double> channelMean = [];
  List<double> channelInvStd = [];

  for (int c = 0; c < numChannels; c = c + 1) {
    double mean;
    double variance;

    if (isTraining) {
      int cOffset = c * planeSize;

      double sum = 0.0;
      for (int i = 0; i < planeSize; i = i + 1) {
        sum = sum + x.data[cOffset + i];
      }
      mean = sum / numElements;

      double varianceSum = 0.0;
      for (int i = 0; i < planeSize; i = i + 1) {
        double diff = x.data[cOffset + i] - mean;
        varianceSum = varianceSum + (diff * diff);
      }
      variance = varianceSum / numElements;

      // Exponential moving average of the batch statistics (in place).
      runningMean[c] = momentum * runningMean[c] + (1.0 - momentum) * mean;
      runningVariance[c] =
          momentum * runningVariance[c] + (1.0 - momentum) * variance;
    } else {
      mean = runningMean[c];
      variance = runningVariance[c];
    }

    channelMean.add(mean);
    channelInvStd.add(1.0 / sqrt(variance + epsilon));
  }

  // Normalized values in flat layout; kept for the backward pass.
  List<double> xHatFlat = [];
  Tensor3D outValue = [];

  for (int c = 0; c < numChannels; c = c + 1) {
    Matrix m = [];
    double mean = channelMean[c];
    double invStd = channelInvStd[c];
    double gammaVal = gamma.data[c];
    double betaVal = beta.data[c];
    int cOffset = c * planeSize;

    for (int h = 0; h < height; h = h + 1) {
      Vector row = [];
      int hOffset = h * width;
      for (int w = 0; w < width; w = w + 1) {
        int flatIdx = cOffset + hOffset + w;
        double val = (x.data[flatIdx] - mean) * invStd;
        xHatFlat.add(val);
        row.add(gammaVal * val + betaVal);
      }
      m.add(row);
    }
    outValue.add(m);
  }

  Tensor<Tensor3D> out = Tensor<Tensor3D>(outValue);

  out.creator = Node(
    [x, gamma, beta],
    () {
      for (int c = 0; c < numChannels; c = c + 1) {
        double invStd = channelInvStd[c];
        double gVal = gamma.data[c];
        int cOffset = c * planeSize;

        // Reductions over the channel plane:
        //   sumGrad     = sum(dy)         -> gradient of beta
        //   sumGradXHat = sum(dy * xHat)  -> gradient of gamma
        double sumGrad = 0.0;
        double sumGradXHat = 0.0;
        for (int i = 0; i < planeSize; i = i + 1) {
          int flatIdx = cOffset + i;
          double gradOut = out.grad[flatIdx];
          sumGrad = sumGrad + gradOut;
          sumGradXHat = sumGradXHat + gradOut * xHatFlat[flatIdx];
        }
        gamma.grad[c] = gamma.grad[c] + sumGradXHat;
        beta.grad[c] = beta.grad[c] + sumGrad;

        if (isTraining) {
          // The batch mean and variance are functions of x, so the gradient
          // has to flow through them as well:
          //   dx = gamma * invStd * (dy - mean(dy) - xHat * mean(dy * xHat))
          double meanGrad = sumGrad / numElements;
          double meanGradXHat = sumGradXHat / numElements;
          double scale = gVal * invStd;
          for (int i = 0; i < planeSize; i = i + 1) {
            int flatIdx = cOffset + i;
            double centered = out.grad[flatIdx] -
                meanGrad -
                xHatFlat[flatIdx] * meanGradXHat;
            x.grad[flatIdx] = x.grad[flatIdx] + scale * centered;
          }
        } else {
          // Running statistics are constants with respect to x, so the
          // normalization is a plain per-channel affine map.
          for (int i = 0; i < planeSize; i = i + 1) {
            int flatIdx = cOffset + i;
            x.grad[flatIdx] =
                x.grad[flatIdx] + out.grad[flatIdx] * gVal * invStd;
          }
        }
      }
    },
    opName: 'batch_norm_2d',
    cost: numChannels * height * width * 4,
  );

  return out;
}