/// Forward pass of the (shifted) window-attention transformer block.
///
/// [x] is the token tensor of shape (batch, height*width, dim);
/// [h] and [w] are the spatial dimensions of the feature map.
/// Returns a tensor of the same shape as [x].
Tensor forward(Tensor x, int h, int w) {
  final batch = x.shape[0];
  final tokens = h * w;

  // ─── Window attention (pre-norm + residual) ─────────────────────
  final attnResidual = x;
  // LayerNorm, then lay tokens out spatially: (batch, h, w, dim).
  var feat = norm1.forward(x).reshape([batch, h, w, dim]);

  // For shifted-window blocks, roll the map and build the matching
  // attention mask so windows do not attend across the wrap-around seam.
  final shifted = shiftSize > 0;
  Tensor? mask;
  if (shifted) {
    feat = _cyclicShift(feat, -shiftSize, h, w);
    mask = _computeAttnMask(h, w);
  }

  // Partition into non-overlapping windows and flatten each window to a
  // token sequence: (numWindows*batch, windowSize*windowSize, dim).
  final parts = _windowPartition(feat, windowSize);
  final numWin = parts.shape[0];
  var winOut = attn.forward(
    parts.reshape([numWin, windowSize * windowSize, dim]),
    mask: mask,
  );

  // Stitch windows back into the (batch, h, w, dim) layout and undo the
  // cyclic shift if one was applied.
  winOut = winOut.reshape([numWin, windowSize, windowSize, dim]);
  feat = _windowReverse(winOut, windowSize, h, w, batch);
  if (shifted) {
    feat = _cyclicShift(feat, shiftSize, h, w);
  }

  // Back to token form and add the residual: (batch, h*w, dim).
  feat = attnResidual + feat.reshape([batch, tokens, dim]);

  // ─── MLP (pre-norm + residual) ──────────────────────────────────
  final hidden = mlpFc1.forward(norm2.forward(feat)).gelu();
  return feat + mlpFc2.forward(hidden);
}