forward method

Tensor forward(Tensor x, int h, int w)

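Forward pass: merges each 2x2 neighborhood of patches into a single token, then applies layer normalization and a linear projection.

Input: `(batch, h*w, dim)`, where `h` and `w` are the spatial height and width of the patch grid.
Output: `(batch, (h/2)*(w/2), 2*dim)`, where `/` denotes integer division.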

Implementation

/// Forward pass: merges each 2x2 neighborhood of patches into one token.
///
/// [x] must be shaped `(batch, h * w, dim)`. Here [h] and [w] are the spatial
/// height and width of the token grid, and `dim` is this layer's channel count.
/// The four tokens of every 2x2 neighborhood are concatenated along the
/// channel axis, giving `4 * dim` channels. The result is then passed through
/// `norm` and `reduction` (layer norm, then linear projection).
///
/// The documented output shape is `(batch, (h/2) * (w/2), 2 * dim)`. The final
/// channel count is whatever `reduction` produces.
///
/// Throws [ArgumentError] in two cases:
/// - [h] or [w] is negative or odd. An odd size would otherwise make the last
///   row/column be silently dropped by the integer division below.
/// - `x.shape` is not `(batch, h * w, dim)`.
Tensor forward(Tensor x, int h, int w) {
  if (h < 0 || w < 0 || h.isOdd || w.isOdd) {
    throw ArgumentError(
        'forward requires non-negative, even spatial dims; got h=$h, w=$w');
  }

  final shape = x.shape;
  // Check the token and channel axes explicitly. A reshape alone would
  // accept any shape with the same element count and silently scramble it.
  if (shape.length != 3 || shape[1] != h * w || shape[2] != dim) {
    throw ArgumentError(
        'expected x of shape (batch, ${h * w}, $dim); got $shape');
  }

  final batch = shape[0];

  // Spatial view of the tokens: (batch, h, w, dim).
  final grid = x.reshape([batch, h, w, dim]);

  final newH = h ~/ 2;
  final newW = w ~/ 2;

  // Each merged token holds its 4 source tokens back to back, in the channel
  // order [top-left | top-right | bottom-left | bottom-right].
  // NOTE(review): the reference Swin Transformer PatchMerging concatenates
  // [top-left | bottom-left | top-right | bottom-right]. Confirm which order
  // any pretrained `reduction` weights expect before loading them.
  final outData = Float32List(batch * newH * newW * 4 * dim);

  for (var b = 0; b < batch; b++) {
    for (var i = 0; i < newH; i++) {
      final top = 2 * i;
      final bottom = top + 1;
      for (var j = 0; j < newW; j++) {
        final left = 2 * j;
        final right = left + 1;
        final dstBase = ((b * newH + i) * newW + j) * 4 * dim;
        _copyPatch(grid, b, top, left, w, dim, outData, dstBase);
        _copyPatch(grid, b, top, right, w, dim, outData, dstBase + dim);
        _copyPatch(grid, b, bottom, left, w, dim, outData, dstBase + 2 * dim);
        _copyPatch(grid, b, bottom, right, w, dim, outData, dstBase + 3 * dim);
      }
    }
  }

  final merged = Tensor(outData, [batch, newH * newW, 4 * dim]);

  // Layer norm, then linear projection of the 4*dim concatenated channels.
  return reduction.forward(norm.forward(merged));
}