scatterHeadsGPU function
Implementation
/// Records the ops that place each tensor in [heads] into its own column
/// range of a single new `[seqLen, dModel]` tensor, and returns that tensor.
///
/// `seqLen` and `dHead` are read from `heads[0].shape`; head `i` is assigned
/// the column range starting at `i * dHead` and ending at `i * dHead + dHead`.
/// NOTE(review): every head is assumed to share the shape of `heads[0]`;
/// this is not checked here — confirm against callers.
///
/// Forward recording on [tape]: one `OP_FILL` with the output id and `0.0`,
/// then one `OP_SLICE_COLUMN_BACKWARD` per head (head id, output id, start
/// column, end column). The returned tensor's `creator` is a `GPUNode` whose
/// callback records one `OP_SLICE_COLUMN` per head, from `'<out.id>_grad'`
/// into `'<head.id>_grad'`, over the same column range.
///
/// Throws [ArgumentError] if [heads] is empty, or if `heads.length * dHead`
/// exceeds [dModel] (the last head's columns would not fit in the output).
GPUTensor<Matrix> scatterHeadsGPU(List<GPUTensor<Matrix>> heads, int dModel, CommandBuffer tape) {
  if (heads.isEmpty) {
    throw ArgumentError.value(heads, 'heads', 'must contain at least one head');
  }

  // Snapshot the list once so the forward ops, the node's inputs and the
  // backward closure all see the same heads even if the caller mutates
  // its list after this call returns.
  final List<GPUTensor<Matrix>> inputs = [...heads];

  int seqLen = inputs[0].shape[0];
  int dHead = inputs[0].shape[1];

  // Narrower inputs are allowed (remaining columns keep the fill value);
  // wider inputs would address columns past the output's width.
  if (inputs.length * dHead > dModel) {
    throw ArgumentError.value(
      dModel,
      'dModel',
      'must be at least heads.length * dHead (${inputs.length} * $dHead)',
    );
  }

  GPUTensor<Matrix> out = GPUTensor<Matrix>.empty([seqLen, dModel]);

  // Record a fill of the output with 0.0 before any head is scattered in.
  tape.putInt(OP_FILL);
  tape.putString(out.id);
  tape.putFloat(0.0);

  // NOTE(review): OP_SLICE_COLUMN_BACKWARD is reused here as the forward
  // "scatter" op; presumably it writes/accumulates the source into the given
  // column range of the destination — confirm against the kernel.
  for (int i = 0; i < inputs.length; i++) {
    int startCol = i * dHead;
    int endCol = startCol + dHead;
    tape.putInt(OP_SLICE_COLUMN_BACKWARD);
    tape.putString(inputs[i].id);
    tape.putString(out.id);
    tape.putInt(startCol);
    tape.putInt(endCol);
  }

  out.creator = GPUNode(
    inputs,
    (CommandBuffer bTape) {
      // Backward: record a column slice of the output gradient for each
      // head's gradient, using the same column ranges as the forward pass.
      for (int i = 0; i < inputs.length; i++) {
        int startCol = i * dHead;
        int endCol = startCol + dHead;
        bTape.putInt(OP_SLICE_COLUMN);
        bTape.putString('${out.id}_grad');
        bTape.putString('${inputs[i].id}_grad');
        bTape.putInt(startCol);
        bTape.putInt(endCol);
      }
    },
    opName: 'scatter_heads',
    cost: seqLen * dModel,
  );
  return out;
}