forward method

Tensor forward(
  Tensor x,
  int h,
  int w
)

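Runs the forward pass: (shifted-)window self-attention followed by an MLP, each wrapped in a residual connection.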

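- `x`: input tokens of shape `(batch, height*width, dim)`.
- `h`, `w`: spatial dimensions (height and width) of the token grid; `h * w` must equal the sequence length of `x`.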

Implementation

/// Applies one shifted-window transformer block to a flattened token grid.
///
/// [x] carries `batch` sequences of `h * w` tokens, each `dim` wide. [h] and
/// [w] describe how those tokens are laid out spatially.
///
/// Two residual sub-layers run in sequence:
///
/// 1. `norm1`, then an optional cyclic shift (only when `shiftSize > 0`),
///    then window partitioning, `attn`, window merging, and the inverse
///    shift. The result is added onto the block input.
/// 2. `norm2`, `mlpFc1`, GELU and `mlpFc2`. The result is added onto the
///    output of step 1.
///
/// Returns the output of step 1 plus the MLP branch.
///
/// NOTE(review): no padding happens here. Presumably [h] and [w] must be
/// multiples of `windowSize`; confirm against `_windowPartition`.
Tensor forward(Tensor x, int h, int w) {
  final batchSize = x.shape[0];
  final tokensPerImage = h * w;
  final tokensPerWindow = windowSize * windowSize;
  final hasShift = shiftSize > 0;

  // ── Sub-layer 1: (shifted-)window self-attention ──
  final attnResidual = x;
  final normedGrid = norm1.forward(x).reshape([batchSize, h, w, dim]);

  // Roll the grid so that the next partition straddles the previous
  // window borders. The mask is only built when a shift is applied.
  final shiftedGrid =
      hasShift ? _cyclicShift(normedGrid, -shiftSize, h, w) : normedGrid;
  final Tensor? attnMask = hasShift ? _computeAttnMask(h, w) : null;

  // Windows arrive as (numWindows*batch, windowSize, windowSize, dim).
  // Attention sees each window as a flat sequence of tokens.
  final windows = _windowPartition(shiftedGrid, windowSize);
  final windowCount = windows.shape[0];
  final attended = attn
      .forward(
        windows.reshape([windowCount, tokensPerWindow, dim]),
        mask: attnMask,
      )
      .reshape([windowCount, windowSize, windowSize, dim]);

  // Stitch the windows back into (batch, h, w, dim) and undo the roll.
  final mergedGrid = _windowReverse(attended, windowSize, h, w, batchSize);
  final unshiftedGrid =
      hasShift ? _cyclicShift(mergedGrid, shiftSize, h, w) : mergedGrid;
  final afterAttn =
      attnResidual + unshiftedGrid.reshape([batchSize, tokensPerImage, dim]);

  // ── Sub-layer 2: position-wise MLP ──
  final hidden = mlpFc1.forward(norm2.forward(afterAttn)).gelu();
  return afterAttn + mlpFc2.forward(hidden);
}