batchNorm2dGPU function
Implementation
/// Records a 2D batch-normalization forward op on [tape] and returns the
/// tensor that will hold its result.
///
/// Nothing is computed here: the function only allocates the output and two
/// per-channel scratch vectors, then serializes the op code, tensor ids and
/// scalar settings into [tape] for later execution.
///
/// * [input] - source tensor; its shape is read as
///   `[channels, height, width]`.
/// * [gamma], [beta] - per-channel parameters, forwarded by id (presumably
///   scale and shift - confirm against the kernel).
/// * [runningMean], [runningVariance] - forwarded by id to the forward op
///   only; they are not referenced by the backward op.
/// * [momentum], [epsilon], [isTraining] - written to [tape] verbatim; how
///   they are interpreted is up to the consumer of the tape.
///
/// The returned tensor has the same shape as [input] and carries a
/// [GPUNode] creator whose parents are `[input, gamma, beta]` and whose
/// callback records the matching backward op.
GPUTensor<Tensor3D> batchNorm2dGPU(
  GPUTensor<Tensor3D> input,
  GPUTensor<Vector> gamma,
  GPUTensor<Vector> beta,
  GPUTensor<Vector> runningMean,
  GPUTensor<Vector> runningVariance,
  double momentum,
  double epsilon,
  bool isTraining,
  CommandBuffer tape,
) {
  final int channels = input.shape[0];
  final int height = input.shape[1];
  final int width = input.shape[2];

  // Allocation order is kept as output -> mean -> inverse variance in case
  // tensor ids are assigned at construction time.
  final result = GPUTensor<Tensor3D>.empty(<int>[channels, height, width]);
  final meanCache = GPUTensor<Vector>(List<double>.filled(channels, 0.0));
  final invVarCache = GPUTensor<Vector>(List<double>.filled(channels, 0.0));

  // Forward record layout: op code, five input ids, three output ids,
  // then the scalar settings. The write order is the wire format.
  tape
    ..putInt(OP_BATCH_NORM_2D_FORWARD)
    ..putString(input.id)
    ..putString(gamma.id)
    ..putString(beta.id)
    ..putString(runningMean.id)
    ..putString(runningVariance.id)
    ..putString(result.id)
    ..putString(meanCache.id)
    ..putString(invVarCache.id)
    ..putFloat(momentum)
    ..putFloat(epsilon)
    ..putBool(isTraining);

  final int elementCount = channels * height * width;

  result.creator = GPUNode(
    [input, gamma, beta],
    (CommandBuffer backTape) {
      // Backward record layout: op code, upstream gradient id, the forward
      // operands and cached statistics, then the three gradient targets.
      backTape
        ..putInt(OP_BATCH_NORM_2D_BACKWARD)
        ..putString('${result.id}_grad')
        ..putString(input.id)
        ..putString(gamma.id)
        ..putString(meanCache.id)
        ..putString(invVarCache.id)
        ..putString('${input.id}_grad')
        ..putString('${gamma.id}_grad')
        ..putString('${beta.id}_grad');
    },
    opName: 'batchNorm2dGPU',
    cost: elementCount * 4,
  );

  return result;
}