forward method
Forward pass of windowed multi-head self-attention with relative position bias.
x: input of shape (numWindows * batch, windowSize * windowSize, dim)
mask: optional additive attention mask of shape (numWindows, windowSize * windowSize, windowSize * windowSize)
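Within each window, the method computes softmax(q · kᵀ / sqrt(headDim) + relativePositionBias [+ mask]) · v independently per head. As a worked shape example (numbers chosen purely for illustration): with windowSize = 7, an 8 × 8 grid of windows, and a batch of 2, numWindows = 64, so x has shape (128, 49, dim) and the optional mask has shape (64, 49, 49).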
Implementation
Tensor forward(Tensor x, {Tensor? mask}) {
  final bw = x.shape[0]; // numWindows * batch
  final n = x.shape[1]; // windowSize * windowSize
  final scale = 1.0 / math.sqrt(headDim.toDouble());

  // QKV projection
  var qkvOut = qkv.forward(x); // (bw, n, 3*dim)
  qkvOut = qkvOut.reshape([bw, n, 3, numHeads, headDim]);
  qkvOut = qkvOut.permute([2, 0, 3, 1, 4]); // (3, bw, numHeads, n, headDim)
  final q = qkvOut[0]; // (bw, numHeads, n, headDim)
  final k = qkvOut[1];
  final v = qkvOut[2];

  // Attention scores
  final kT = k.transpose(2, 3); // (bw, numHeads, headDim, n)
  var attn = q.matmul(kT).mulScalar(scale); // (bw, numHeads, n, n)

  // Add relative position bias
  final bias = _getRelativePositionBias();
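  // bias has shape (numHeads, n, n); broadcast it over every window in the batch.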
  attn = attn + bias.unsqueeze(0).expand(attn.shape);

  // Apply window mask if provided
  if (mask != null) {
    final numWindows = mask.shape[0];
    final batch = bw ~/ numWindows;
    attn = attn.reshape([batch, numWindows, numHeads, n, n]);
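    // Broadcast the (numWindows, n, n) mask across the batch and head dimensions.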
    final expandedMask = mask.unsqueeze(1).unsqueeze(0).expand(attn.shape);
    attn = attn + expandedMask;
    attn = attn.reshape([bw, numHeads, n, n]);
  }

  // Softmax
  attn = attn.softmax(-1);

  // Apply attention to values
  var output = attn.matmul(v); // (bw, numHeads, n, headDim)
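  // Merge heads: (bw, numHeads, n, headDim) -> (bw, n, numHeads * headDim = dim).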
  output = output.permute([0, 2, 1, 3]).reshape([bw, n, dim]);

  // Output projection
  return proj.forward(output);
}
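A minimal usage sketch is shown below. The WindowAttention constructor, its named parameters, and the Tensor.randn / Tensor.zeros factories are assumptions made for illustration; only forward itself is documented above.

// Hypothetical usage: WindowAttention(...), Tensor.randn and Tensor.zeros
// are assumed names; substitute whatever this library actually exposes.
final attn = WindowAttention(dim: 96, numHeads: 3, windowSize: 7); // headDim = 32

// 8 x 8 window grid, batch of 2 -> 64 * 2 = 128 windows of 7 * 7 = 49 tokens.
final x = Tensor.randn([128, 49, 96]);

// Without a mask, every token in a window attends to every other token.
final out = attn.forward(x); // (128, 49, 96)

// With a per-window additive mask (one entry per window, not per batch item).
final mask = Tensor.zeros([64, 49, 49]);
final maskedOut = attn.forward(x, mask: mask); // (128, 49, 96)

Because the mask is added before the softmax, blocked positions are normally encoded as large negative values so that their attention weights collapse to zero; the all-zero mask above is only a placeholder.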