batchNorm1dMath function
Implementation
/// Normalizes the first [numFeatures] entries of [x] with the running
/// statistics, then applies a per-feature scale ([gamma]) and shift ([beta]):
///
///     xHat[i] = (x[i] - runningMean[i]) / sqrt(runningVariance[i] + epsilon)
///     out[i]  = gamma[i] * xHat[i] + beta[i]
///
/// When [isTraining] is true, [runningMean] and [runningVariance] are first
/// updated IN PLACE as an exponential moving average weighted by [momentum]
/// (weight of the old value), treating [x] as a single sample: its per-feature
/// mean is the value itself and its variance contribution is 0.0. The
/// normalization then uses the freshly updated running statistics. When
/// [isTraining] is false the running statistics are read but not modified.
///
/// NOTE(review): because the training-mode variance contribution is always
/// 0.0, every training call only scales [runningVariance] by [momentum];
/// confirm this is the intended behavior for this single-sample variant.
///
/// The returned tensor carries a `Node` creator over `[x, gamma, beta]` whose
/// backward closure accumulates into `gamma.grad`, `beta.grad` and `x.grad`.
/// The running statistics are treated as constants in the backward pass.
///
/// No length validation is performed: [x], [gamma], [beta], [runningMean] and
/// [runningVariance] are all indexed from 0 to `numFeatures - 1`.
Tensor<Vector> batchNorm1dMath(
  Tensor<Vector> x,
  Tensor<Vector> gamma,
  Tensor<Vector> beta,
  Vector runningMean,
  Vector runningVariance,
  int numFeatures,
  bool isTraining,
  double momentum,
  double epsilon,
) {
  if (isTraining) {
    // Snapshot the per-sample statistics before touching the running
    // buffers, so a bad index in `x.data` fails before any mutation.
    Vector sampleMean = [];
    Vector sampleVariance = [];
    for (int i = 0; i < numFeatures; i = i + 1) {
      sampleMean.add(x.data[i]);
      sampleVariance.add(0.0);
    }
    for (int i = 0; i < numFeatures; i = i + 1) {
      runningMean[i] =
          momentum * runningMean[i] + (1.0 - momentum) * sampleMean[i];
      runningVariance[i] =
          momentum * runningVariance[i] + (1.0 - momentum) * sampleVariance[i];
    }
  }

  // Both modes normalize with the running statistics (post-update when
  // training). The inverse standard deviation is captured here, at forward
  // time, so the backward closure does not depend on later in-place
  // mutations of `runningVariance` by subsequent training calls.
  Vector xHat = [];
  Vector invStd = [];
  for (int i = 0; i < numFeatures; i = i + 1) {
    final double std = sqrt(runningVariance[i] + epsilon);
    invStd.add(1.0 / std);
    xHat.add((x.data[i] - runningMean[i]) / std);
  }

  Vector outValue = [];
  for (int i = 0; i < numFeatures; i = i + 1) {
    outValue.add(gamma.data[i] * xHat[i] + beta.data[i]);
  }

  Tensor<Vector> out = Tensor<Vector>(outValue);
  out.creator = Node(
    [x, gamma, beta],
    () {
      for (int i = 0; i < numFeatures; i = i + 1) {
        // d out / d gamma = xHat, d out / d beta = 1,
        // d out / d x = gamma * invStd (running stats held constant).
        gamma.grad[i] = gamma.grad[i] + out.grad[i] * xHat[i];
        beta.grad[i] = beta.grad[i] + out.grad[i];
        x.grad[i] = x.grad[i] + out.grad[i] * gamma.data[i] * invStd[i];
      }
    },
    opName: 'batch_norm_1d',
    cost: numFeatures * 4,
  );
  return out;
}