compile method

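void compile(List<int> inputShape, List<int> targetShape, double learningRate)

Parameters:
1. inputShape: shape of the empty placeholder input tensor the network is traced against.
2. targetShape: shape of the empty placeholder target tensor used by the MSE loss.
3. learningRate: learning rate passed to the SGD optimizer.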

Implementation

/// Builds the network and records the forward, backward, optimizer and
/// zero-grad command tapes for the given input/target shapes.
///
/// Steps:
/// 1. Allocates empty placeholder tensors [inputRef] (shape [inputShape]) and
///    [targetRef] (shape [targetShape]).
/// 2. Calls `build` on each layer with the output of a throwaway forward pass
///    through the preceding layers; that pass's intermediates are freed and
///    its tape is discarded.
/// 3. Collects every layer's `parameters` into [allParams].
/// 4. Creates an [SGDGPU] optimizer over [allParams] with [learningRate].
/// 5. Records the tapes:
///    * [fTape]: the forward pass followed by the MSE loss against [targetRef].
///    * [bTape]: the optimizer's zero-grad commands, OP_ZERO_GRAD for every
///      recorded intermediate plus [inputRef] and [outputRef], backpropagation
///      from [lossRef] (its gradient filled with ones), then one
///      OP_CLIP_GRAD_VALUE per parameter gradient using [gradClipValue].
///    * [oTape]: one optimizer step.
///    * [zTape]: a standalone copy of the optimizer's zero-grad commands.
///
/// [gradClipValue] defaults to 1.0, the value previously hard-coded here, so
/// existing callers see identical tapes.
///
/// When compile is called again, the intermediates recorded by the previous
/// trace are freed before re-tracing instead of being silently dropped.
void compile(List<int> inputShape, List<int> targetShape, double learningRate,
    {double gradClipValue = 1.0}) {
  // 1. Create dummy input and target references to establish VRAM footprint.
  // NOTE(review): inputRef/targetRef/lossRef from a previous compile are
  // overwritten without being freed; whether they can be freed here depends
  // on how these fields are declared (late/nullable) — confirm.
  inputRef = GPUTensor<Matrix>.empty(inputShape);
  targetRef = GPUTensor<Matrix>.empty(targetShape);

  // 2. Build layers & cascade shapes. Each layer is built from the output of
  // the previous one; this pass exists only to propagate shapes.
  GPUTensor<dynamic> current = inputRef;
  CommandBuffer buildTape = CommandBuffer(); // Throwaway tape (never stored)
  List<GPUTensor> buildTrash = <GPUTensor>[]; // Throwaway intermediates

  for (int i = 0; i < layers.length; i = i + 1) {
    layers[i].build(current);
    current = layers[i].forward(current, buildTape, buildTrash);
  }
  _freeAll(buildTrash);

  // 3. Populate allParams.
  allParams.clear();
  for (int i = 0; i < layers.length; i = i + 1) {
    List<GPUTensor> params = layers[i].parameters;
    for (int j = 0; j < params.length; j = j + 1) {
      allParams.add(params[j]);
    }
  }

  // 4. Instantiate the optimizer.
  optimizer = SGDGPU(allParams, learningRate);

  // 5a. Trace Forward Tape (fTape).
  // Intermediates left over from a previous compile are freed first:
  // clearing the list alone would leak their GPU buffers, the same way
  // buildTrash would leak if it were not freed above.
  _freeAll(intermediates);
  intermediates.clear();
  CommandBuffer fCommand = CommandBuffer();
  current = inputRef;

  for (int i = 0; i < layers.length; i = i + 1) {
    current = layers[i].forward(current, fCommand, intermediates);
  }

  outputRef = current;
  lossRef = mseMatrixGPU(outputRef as GPUTensor<Matrix>, targetRef, fCommand);
  fTape = fCommand.bytes();

  // 5b. Trace Backward & ZeroGrad Tape (bTape).
  CommandBuffer bCommand = CommandBuffer();

  optimizer.zeroGrad(bCommand);
  for (int i = 0; i < intermediates.length; i = i + 1) {
    _putZeroGrad(bCommand, intermediates[i]);
  }
  _putZeroGrad(bCommand, inputRef);
  _putZeroGrad(bCommand, outputRef);

  // Triggers backpropagation and fills the loss gradient with ones.
  lossRef.backward(bCommand, fillOnes: true);

  // Clip every parameter gradient to gradClipValue.
  for (int i = 0; i < allParams.length; i = i + 1) {
    bCommand.putInt(OP_CLIP_GRAD_VALUE);
    bCommand.putString('${allParams[i].id}_grad');
    bCommand.putFloat(gradClipValue);
  }
  bTape = bCommand.bytes();

  // 5c. Trace Optimize Tape (oTape).
  CommandBuffer oCommand = CommandBuffer();
  optimizer.step(oCommand);
  oTape = oCommand.bytes();

  // 5d. Optional standalone ZeroGrad Tape (zTape) for manual control.
  CommandBuffer zCommand = CommandBuffer();
  optimizer.zeroGrad(zCommand);
  zTape = zCommand.bytes();
}

/// Calls `free()` on every tensor in [tensors]; the list itself is not
/// modified.
void _freeAll(List<GPUTensor> tensors) {
  for (int i = 0; i < tensors.length; i = i + 1) {
    tensors[i].free();
  }
}

/// Appends an OP_ZERO_GRAD command targeting the gradient buffer of [tensor],
/// which is addressed by the name '<id>_grad'.
void _putZeroGrad(CommandBuffer command, GPUTensor tensor) {
  command.putInt(OP_ZERO_GRAD);
  command.putString('${tensor.id}_grad');
}