batchNorm1dGPU function
Implementation
/// Records a 1-D batch-normalization forward op on [tape] and attaches the
/// matching backward op to the returned tensor.
///
/// No arithmetic happens on the host. The function does three things:
/// 1. Allocates three zero-filled vectors: the result and two statistic
///    buffers that the backward op references.
/// 2. Writes an `OP_BATCH_NORM_1D_FORWARD` command to [tape]. The command
///    names every involved tensor by `id`, then writes [momentum] and
///    [epsilon] via `putFloat`, then the [isTraining] flag via `putBool`.
/// 3. Sets the result's `creator` to a `GPUNode`.
///
/// About the `GPUNode`:
/// - Its parents are [input], [gamma] and [beta]. [runningMean] and
///   [runningVariance] are only referenced by the forward command, not
///   registered as parents.
/// - Its backward callback writes `OP_BATCH_NORM_1D_BACKWARD` followed by a
///   fixed sequence of ids. These are `<result id>_grad`, the input, gamma
///   and the two saved statistic buffers, then `<input id>_grad`,
///   `<gamma id>_grad` and `<beta id>_grad`.
///
/// NOTE(review): `input.shape[0]` is used as the feature count. [gamma],
/// [beta] and the running statistics are assumed to have the same length;
/// nothing here checks that.
///
/// NOTE(review): whether the running statistics are updated in place
/// (presumably only when [isTraining] is true) is decided by the GPU-side
/// handler for this opcode, which is not visible here. Please confirm.
GPUTensor<Vector> batchNorm1dGPU(
  GPUTensor<Vector> input,
  GPUTensor<Vector> gamma,
  GPUTensor<Vector> beta,
  GPUTensor<Vector> runningMean,
  GPUTensor<Vector> runningVariance,
  double momentum,
  double epsilon,
  bool isTraining,
  CommandBuffer tape,
) {
  final int featureCount = input.shape[0];

  // A fresh zero vector with one slot per feature.
  GPUTensor<Vector> zeros() =>
      GPUTensor<Vector>(List<double>.filled(featureCount, 0.0));

  // Allocation order is kept as: result, mean cache, inverse-variance cache.
  final GPUTensor<Vector> result = zeros();
  final GPUTensor<Vector> meanCache = zeros();
  final GPUTensor<Vector> invVarCache = zeros();

  // Forward command: opcode, then tensor ids, then scalars and the mode flag.
  // The tape is a sequential stream, so this order is part of its format.
  tape.putInt(OP_BATCH_NORM_1D_FORWARD);
  for (final id in [
    input.id,
    gamma.id,
    beta.id,
    runningMean.id,
    runningVariance.id,
    result.id,
    meanCache.id,
    invVarCache.id,
  ]) {
    tape.putString(id);
  }
  tape
    ..putFloat(momentum)
    ..putFloat(epsilon)
    ..putBool(isTraining);

  result.creator = GPUNode(
    [input, gamma, beta],
    (CommandBuffer bTape) {
      // Ids are built here, when the backward op is recorded, not at
      // forward time. The id sequence is part of the tape format.
      bTape.putInt(OP_BATCH_NORM_1D_BACKWARD);
      for (final id in [
        '${result.id}_grad',
        input.id,
        gamma.id,
        meanCache.id,
        invVarCache.id,
        '${input.id}_grad',
        '${gamma.id}_grad',
        '${beta.id}_grad',
      ]) {
        bTape.putString(id);
      }
    },
    opName: 'batchNorm1dGPU',
    cost: featureCount * 4,
  );
  return result;
}