scatterHeadsGPU function

GPUTensor<Matrix> scatterHeadsGPU(
  List<GPUTensor<Matrix>> heads,
  int dModel,
  CommandBuffer tape
)

Implementation

/// Records the ops that place each tensor in [heads] into its own column band
/// of a freshly allocated `[seqLen, dModel]` tensor, and returns that tensor.
///
/// `seqLen` and `dHead` are read from `heads[0].shape`; head `i` is assigned
/// columns `[i * dHead, (i + 1) * dHead)` of the result.
///
/// Forward ops written to [tape], in order:
///  * one `OP_FILL` of the output with `0.0`;
///  * one `OP_SLICE_COLUMN_BACKWARD` per head, carrying the head id, the output
///    id and the column range. NOTE(review): presumably this op writes or
///    accumulates the head into that column range of the output — confirm
///    against the kernel implementation.
///
/// The returned tensor's `creator` is a `GPUNode` over a copy of the heads
/// whose backward callback records one `OP_SLICE_COLUMN` per head, from
/// `'<out.id>_grad'` into `'<head.id>_grad'` over the same column range.
///
/// Throws [ArgumentError] if [heads] is empty, if any head's first two
/// dimensions differ from those of `heads[0]`, or if the heads need more than
/// [dModel] columns. Fewer columns than [dModel] is accepted; the unused
/// columns are only touched by the initial fill.
GPUTensor<Matrix> scatterHeadsGPU(List<GPUTensor<Matrix>> heads, int dModel, CommandBuffer tape) {
  if (heads.isEmpty) {
    throw ArgumentError('scatterHeadsGPU: heads must contain at least one tensor');
  }

  // Snapshot the list: the backward callback runs later, and must see exactly
  // the heads the forward ops were recorded for even if the caller mutates
  // its list in between.
  final List<GPUTensor<Matrix>> inputs = [...heads];

  int seqLen = inputs[0].shape[0];
  int dHead = inputs[0].shape[1];

  // Column offsets are derived from head 0's width, so every head has to
  // share its shape for the recorded column ranges to be meaningful.
  for (int i = 1; i < inputs.length; i = i + 1) {
    if (inputs[i].shape[0] != seqLen || inputs[i].shape[1] != dHead) {
      throw ArgumentError(
        'scatterHeadsGPU: heads[$i] has shape '
        '[${inputs[i].shape[0]}, ${inputs[i].shape[1]}], '
        'expected [$seqLen, $dHead]',
      );
    }
  }

  // The last head ends at column inputs.length * dHead; beyond dModel that
  // range would fall outside the columns allocated for the output.
  if (inputs.length * dHead > dModel) {
    throw ArgumentError.value(
      dModel,
      'dModel',
      'too small for ${inputs.length} heads of width $dHead',
    );
  }

  GPUTensor<Matrix> out = GPUTensor<Matrix>.empty([seqLen, dModel]);

  // Zero the output before the per-head ops are recorded.
  tape.putInt(OP_FILL);
  tape.putString(out.id);
  tape.putFloat(0.0);

  for (int i = 0; i < inputs.length; i = i + 1) {
    int startCol = i * dHead;
    int endCol = startCol + dHead;

    tape.putInt(OP_SLICE_COLUMN_BACKWARD);
    tape.putString(inputs[i].id);
    tape.putString(out.id);
    tape.putInt(startCol);
    tape.putInt(endCol);
  }

  out.creator = GPUNode(
    // Separate copy so the node's input list is independent of the snapshot
    // the callback iterates over.
    [...inputs],
    (CommandBuffer bTape) {
      // Mirror of the forward loop: one op per head, from the output's
      // gradient buffer to that head's gradient buffer, same column range.
      for (int i = 0; i < inputs.length; i = i + 1) {
        int startCol = i * dHead;
        int endCol = startCol + dHead;

        bTape.putInt(OP_SLICE_COLUMN);
        bTape.putString('${out.id}_grad');
        bTape.putString('${inputs[i].id}_grad');
        bTape.putInt(startCol);
        bTape.putInt(endCol);
      }
    },
    opName: 'scatter_heads',
    cost: seqLen * dModel,
  );

  return out;
}