WindowAttention constructor
WindowAttention({
  required int dim,
  required int windowSize,
  required int numHeads,
})
Implementation
/// Creates a window-based multi-head self-attention module.
///
/// [dim] is the total embedding dimension, split evenly across [numHeads]
/// heads (so `headDim = dim ~/ numHeads`); [windowSize] is the side length
/// of the square attention window.
WindowAttention({
  required this.dim,
  required this.windowSize,
  required this.numHeads,
})  : // Guard against silent truncation: `~/` would round down and the
      // heads would no longer cover all of `dim`. Debug-mode only, so
      // existing release behavior is unchanged.
      assert(dim % numHeads == 0,
          'dim ($dim) must be divisible by numHeads ($numHeads)'),
      headDim = dim ~/ numHeads {
  // Fused query/key/value projection: a single linear layer emitting
  // 3 * dim outputs, later split into Q, K, V.
  qkv = Linear(dim, dim * 3);
  // Output projection back to the embedding dimension.
  proj = Linear(dim, dim);
  // Each axis offset between two positions in a window lies in
  // [-(windowSize - 1), windowSize - 1], i.e. 2 * windowSize - 1 distinct
  // values, hence (2W - 1)^2 learnable relative-position bias entries
  // per head for a square window.
  final biasTableSize = (2 * windowSize - 1) * (2 * windowSize - 1);
  relativePositionBiasTable = Tensor.zeros([biasTableSize, numHeads]);
  // Precompute the index mapping each (query, key) position pair within a
  // window to its row in the bias table.
  _computeRelativePositionIndex();
}