forward method

Tensor forward(
  Tensor x, {
  Tensor? mask,
})

Forward pass.

x: input of shape (numWindows * batch, windowSize * windowSize, dim). mask: optional additive attention mask of shape (numWindows, windowSize * windowSize, windowSize * windowSize).

Implementation

/// Computes windowed multi-head self-attention over [x].
///
/// [x] has shape (numWindows * batch, windowSize * windowSize, dim);
/// [mask], when provided, is an additive attention mask of shape
/// (numWindows, windowSize * windowSize, windowSize * windowSize).
/// Returns a tensor with the same shape as [x].
Tensor forward(Tensor x, {Tensor? mask}) {
  final batchWindows = x.shape[0]; // numWindows * batch
  final tokens = x.shape[1]; // windowSize * windowSize
  final invSqrtDim = 1.0 / math.sqrt(headDim.toDouble());

  // Project into stacked Q/K/V and split heads:
  // (batchWindows, tokens, 3*dim) -> (3, batchWindows, numHeads, tokens, headDim).
  final stacked = qkv
      .forward(x)
      .reshape([batchWindows, tokens, 3, numHeads, headDim])
      .permute([2, 0, 3, 1, 4]);
  final queries = stacked[0]; // (batchWindows, numHeads, tokens, headDim)
  final keys = stacked[1];
  final values = stacked[2];

  // Scaled dot-product scores: (batchWindows, numHeads, tokens, tokens).
  var scores =
      queries.matmul(keys.transpose(2, 3)).mulScalar(invSqrtDim);

  // The relative position bias is shared across the batch dimension.
  scores =
      scores + _getRelativePositionBias().unsqueeze(0).expand(scores.shape);

  if (mask != null) {
    // Unfold the window axis out of the batch so the per-window mask
    // broadcasts over both batch and heads.
    final windows = mask.shape[0];
    final batch = batchWindows ~/ windows;
    scores = scores.reshape([batch, windows, numHeads, tokens, tokens]);
    scores = scores + mask.unsqueeze(1).unsqueeze(0).expand(scores.shape);
    scores = scores.reshape([batchWindows, numHeads, tokens, tokens]);
  }

  final weights = scores.softmax(-1);

  // Weighted sum of values, then merge heads back into the channel dim.
  final merged = weights
      .matmul(values)
      .permute([0, 2, 1, 3])
      .reshape([batchWindows, tokens, dim]);

  return proj.forward(merged);
}