WindowAttention constructor

WindowAttention({
  required int dim,
  required int windowSize,
  required int numHeads,
})

Implementation

/// Creates a window-based multi-head attention module.
///
/// [dim] is the total embedding dimension, [windowSize] the side length of
/// the square attention window, and [numHeads] the number of attention
/// heads. The per-head dimension `headDim` is `dim ~/ numHeads`.
WindowAttention({
  required this.dim,
  required this.windowSize,
  required this.numHeads,
}) : headDim = dim ~/ numHeads {
  // Fused query/key/value projection and the output projection.
  qkv = Linear(dim, dim * 3);
  proj = Linear(dim, dim);

  // One learnable bias per relative (row, col) offset inside the window:
  // each axis offset ranges over [-(windowSize - 1), windowSize - 1],
  // giving (2 * windowSize - 1) distinct values per axis.
  final side = 2 * windowSize - 1;
  relativePositionBiasTable = Tensor.zeros([side * side, numHeads]);

  // Precompute the per-position index into the bias table.
  _computeRelativePositionIndex();
}