batchNorm2dMath function
Implementation
/// Applies per-channel 2D batch normalization to [x] and attaches an autograd
/// [Node] that accumulates gradients into `x.grad`, `gamma.grad` and
/// `beta.grad`.
///
/// Layout: `x.data` (and `out.grad` in backward) are read as flat arrays
/// indexed `c * (H * W) + h * W + w`, with `H = x.shape[1]` and
/// `W = x.shape[2]`. The channel count is taken from [numChannels], which is
/// assumed to equal `x.shape[0]` (NOTE(review): not checked here — confirm
/// callers guarantee this). Statistics are computed per channel over the
/// H*W plane only; this function has no separate batch axis.
///
/// Parameters:
/// - [gamma], [beta]: per-channel scale and shift, read as `gamma.data[c]`
///   and `beta.data[c]`.
/// - [runningMean], [runningVariance]: per-channel running statistics.
///   - When [isTraining] is true they are updated IN PLACE as
///     `running = momentum * running + (1 - momentum) * currentStat`.
///     A larger [momentum] therefore keeps more of the old running value.
///     This is the opposite weighting from PyTorch's `momentum` argument.
///   - When [isTraining] is false they are used as the normalization
///     statistics and are not modified.
/// - [isTraining]: selects the current per-channel statistics (true) or the
///   running statistics (false).
/// - [epsilon]: added to the variance before taking the square root.
///
/// The variance used (and blended into [runningVariance]) is the biased
/// variance, i.e. the sum of squared deviations divided by H*W.
///
/// Returns a tensor of shape [numChannels][H][W] holding
/// `gamma[c] * (x - mean[c]) / sqrt(var[c] + epsilon) + beta[c]`.
Tensor<Tensor3D> batchNorm2dMath(
  Tensor<Tensor3D> x,
  Tensor<Vector> gamma,
  Tensor<Vector> beta,
  Vector runningMean,
  Vector runningVariance,
  int numChannels,
  bool isTraining,
  double momentum,
  double epsilon,
) {
  int height = x.shape[1];
  int width = x.shape[2];
  int planeSize = height * width;
  double numElements = planeSize.toDouble();

  // Per-channel statistics used to normalize in this call.
  Vector meanToUse = [];
  Vector varianceToUse = [];
  if (isTraining) {
    for (int c = 0; c < numChannels; c = c + 1) {
      int cOffset = c * planeSize;
      double sum = 0.0;
      for (int i = 0; i < planeSize; i = i + 1) {
        sum = sum + x.data[cOffset + i];
      }
      double mean = sum / numElements;
      // Two-pass variance: avoids the cancellation of E[x^2] - E[x]^2.
      double varianceSum = 0.0;
      for (int i = 0; i < planeSize; i = i + 1) {
        double diff = x.data[cOffset + i] - mean;
        varianceSum = varianceSum + (diff * diff);
      }
      meanToUse.add(mean);
      varianceToUse.add(varianceSum / numElements);
    }
    for (int c = 0; c < numChannels; c = c + 1) {
      runningMean[c] =
          momentum * runningMean[c] + (1.0 - momentum) * meanToUse[c];
      runningVariance[c] =
          momentum * runningVariance[c] + (1.0 - momentum) * varianceToUse[c];
    }
  } else {
    // Aliases the caller's lists; only read during this forward pass.
    meanToUse = runningMean;
    varianceToUse = runningVariance;
  }

  // Forward-time values captured for backward. Backward must not re-read
  // varianceToUse (which may alias runningVariance and be mutated in place
  // by a later training-mode call) or gamma.data (which may be updated
  // before backward runs).
  List<double> invStds = [];
  List<double> gammaVals = [];
  List<double> xHatFlat = [];
  Tensor3D outValue = [];
  for (int c = 0; c < numChannels; c = c + 1) {
    double mean = meanToUse[c];
    double invStd = 1.0 / sqrt(varianceToUse[c] + epsilon);
    double gammaVal = gamma.data[c];
    double betaVal = beta.data[c];
    invStds.add(invStd);
    gammaVals.add(gammaVal);
    int cOffset = c * planeSize;
    Matrix m = [];
    for (int h = 0; h < height; h = h + 1) {
      Vector row = [];
      int hOffset = h * width;
      for (int w = 0; w < width; w = w + 1) {
        int flatIdx = cOffset + hOffset + w;
        double xHat = (x.data[flatIdx] - mean) * invStd;
        xHatFlat.add(xHat);
        row.add(gammaVal * xHat + betaVal);
      }
      m.add(row);
    }
    outValue.add(m);
  }

  Tensor<Tensor3D> out = Tensor<Tensor3D>(outValue);
  out.creator = Node(
    [x, gamma, beta],
    () {
      for (int c = 0; c < numChannels; c = c + 1) {
        int cOffset = c * planeSize;

        // Per-channel reductions shared by the gamma, beta and x gradients.
        double sumGrad = 0.0;
        double sumGradXHat = 0.0;
        for (int i = 0; i < planeSize; i = i + 1) {
          int flatIdx = cOffset + i;
          double gradOut = out.grad[flatIdx];
          sumGrad = sumGrad + gradOut;
          sumGradXHat = sumGradXHat + gradOut * xHatFlat[flatIdx];
        }
        gamma.grad[c] = gamma.grad[c] + sumGradXHat;
        beta.grad[c] = beta.grad[c] + sumGrad;

        double scale = gammaVals[c] * invStds[c];
        if (isTraining) {
          // In training mode the mean and variance are functions of x, so the
          // gradient also flows through them:
          //   dx_i = gamma * invStd * (dy_i - mean(dy) - xHat_i * mean(dy * xHat))
          double meanGrad = sumGrad / numElements;
          double meanGradXHat = sumGradXHat / numElements;
          for (int i = 0; i < planeSize; i = i + 1) {
            int flatIdx = cOffset + i;
            x.grad[flatIdx] = x.grad[flatIdx] +
                scale *
                    (out.grad[flatIdx] -
                        meanGrad -
                        xHatFlat[flatIdx] * meanGradXHat);
          }
        } else {
          // Running statistics are constants with respect to x.
          for (int i = 0; i < planeSize; i = i + 1) {
            int flatIdx = cOffset + i;
            x.grad[flatIdx] = x.grad[flatIdx] + scale * out.grad[flatIdx];
          }
        }
      }
    },
    opName: 'batch_norm_2d',
    cost: numChannels * height * width * 4,
  );
  return out;
}