compile method
Implementation
/// Compiles the model: builds every layer, gathers the trainable parameters,
/// creates the optimizer, and records the forward, backward, optimize and
/// zero-grad command tapes.
///
/// [inputShape] and [targetShape] are passed to `GPUTensor<Matrix>.empty` to
/// allocate the placeholder input/target tensors that every tape is traced
/// against. [learningRate] is forwarded to the `SGDGPU` optimizer.
///
/// [clipValue] is the operand written after each `OP_CLIP_GRAD_VALUE` command
/// emitted for every parameter gradient in the backward tape. It defaults to
/// `1.0`, the value previously hard-coded here. NOTE(review): presumably the
/// kernel clamps each gradient element to [-clipValue, clipValue] — confirm.
///
/// Throws an [ArgumentError] if [clipValue] is not a positive number.
///
/// Fields written: inputRef, targetRef, allParams, optimizer, intermediates,
/// outputRef, lossRef, fTape, bTape, oTape, zTape.
///
/// NOTE(review): calling this more than once replaces inputRef/targetRef and
/// clears intermediates without calling `free()` on the previous tensors, and
/// calls `build` on every layer again — confirm whether recompilation is
/// supported or whether those tensors leak.
void compile(
  List<int> inputShape,
  List<int> targetShape,
  double learningRate, {
  double clipValue = 1.0,
}) {
  // `!(x > 0)` also rejects NaN, which a plain `x <= 0` check would let through.
  if (!(clipValue > 0)) {
    throw ArgumentError.value(clipValue, 'clipValue', 'must be positive');
  }

  // 1. Placeholder input/target tensors; every tape below is traced against
  //    these, so allocating them establishes the VRAM footprint.
  inputRef = GPUTensor<Matrix>.empty(inputShape);
  targetRef = GPUTensor<Matrix>.empty(targetShape);

  // 2. Build each layer against the output of the previous one. The forward
  //    calls here only cascade shapes: their tape is discarded and their
  //    intermediates are freed — in `finally`, so nothing leaks if a layer's
  //    build/forward throws part-way through.
  GPUTensor<dynamic> current = inputRef;
  final CommandBuffer buildTape = CommandBuffer(); // Throwaway tape
  final List<GPUTensor> buildTrash = <GPUTensor>[]; // Throwaway intermediates
  try {
    for (final layer in layers) {
      layer.build(current);
      current = layer.forward(current, buildTape, buildTrash);
    }
  } finally {
    for (final tensor in buildTrash) {
      tensor.free();
    }
  }

  // 3. Gather every layer's parameters, in layer order.
  allParams.clear();
  for (final layer in layers) {
    allParams.addAll(layer.parameters);
  }

  // 4. The optimizer is bound to exactly the parameters gathered above.
  optimizer = SGDGPU(allParams, learningRate);

  // 5a. Forward tape (fTape): every layer's forward pass, then the MSE loss.
  final CommandBuffer fCommand = CommandBuffer();
  intermediates.clear();
  current = inputRef;
  for (final layer in layers) {
    current = layer.forward(current, fCommand, intermediates);
  }
  outputRef = current;
  lossRef = mseMatrixGPU(outputRef as GPUTensor<Matrix>, targetRef, fCommand);
  fTape = fCommand.bytes();

  // 5b. Backward tape (bTape): zero the optimizer's gradients and those of
  //     every traced tensor, backpropagate from the loss, then clip each
  //     parameter gradient.
  final CommandBuffer bCommand = CommandBuffer();
  optimizer.zeroGrad(bCommand);
  for (final tensor in intermediates) {
    bCommand.putInt(OP_ZERO_GRAD);
    bCommand.putString('${tensor.id}_grad');
  }
  bCommand.putInt(OP_ZERO_GRAD);
  bCommand.putString('${inputRef.id}_grad');
  bCommand.putInt(OP_ZERO_GRAD);
  bCommand.putString('${outputRef.id}_grad');
  // Triggers backpropagation and automatically fills loss grad with 1.0
  lossRef.backward(bCommand, fillOnes: true);
  for (final param in allParams) {
    bCommand.putInt(OP_CLIP_GRAD_VALUE);
    bCommand.putString('${param.id}_grad');
    bCommand.putFloat(clipValue);
  }
  bTape = bCommand.bytes();

  // 5c. Optimize tape (oTape): the optimizer's step commands.
  final CommandBuffer oCommand = CommandBuffer();
  optimizer.step(oCommand);
  oTape = oCommand.bytes();

  // 5d. Standalone zero-grad tape (zTape) for manual gradient control.
  final CommandBuffer zCommand = CommandBuffer();
  optimizer.zeroGrad(zCommand);
  zTape = zCommand.bytes();
}