forward method

Tensor forward(Tensor x)

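Forward pass. Projects the input into patch embeddings, flattens the patch grid into a sequence, and applies layer normalization.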

Input: (batch, channels, height, width)
Output: (batch, numPatches, embedDim)
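
The shape relationship can be made concrete with a little arithmetic. The sketch below is plain Dart and does not use the library's Tensor type. It assumes, based on the comment in the implementation, that the projection is a convolution whose kernel size and stride both equal the patch size, with no padding, so each spatial dimension shrinks by exactly that factor. The function name `patchEmbedOutputShape` and the parameter `patchSize` are hypothetical, chosen only for illustration.

```dart
/// Computes the expected output shape of a patch embedding.
///
/// Assumption: the projection uses kernel size == stride == [patchSize]
/// and no padding, so height and width must be divisible by [patchSize].
List<int> patchEmbedOutputShape(
  List<int> inputShape,
  int patchSize,
  int embedDim,
) {
  if (inputShape.length != 4) {
    throw ArgumentError('Expected (batch, channels, height, width)');
  }
  final batch = inputShape[0];
  final height = inputShape[2];
  final width = inputShape[3];
  if (height % patchSize != 0 || width % patchSize != 0) {
    throw ArgumentError('height and width must be divisible by patchSize');
  }
  // One patch per (patchSize x patchSize) tile of the input.
  final numPatches = (height ~/ patchSize) * (width ~/ patchSize);
  return [batch, numPatches, embedDim];
}
```

For example, `patchEmbedOutputShape([2, 3, 224, 224], 16, 768)` gives `[2, 196, 768]`, since 224 / 16 = 14 patches per side and 14 × 14 = 196.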

Implementation

Tensor forward(Tensor x) {
  final batch = x.shape[0];
  // Conv2d: (batch, channels, H, W) → (batch, embedDim, H/patch, W/patch)
  var out = proj.forward(x);
  final h = out.shape[2];
  final w = out.shape[3];
  // Reshape and permute: (batch, embedDim, h, w) → (batch, h*w, embedDim)
  out = out.reshape([batch, embedDim, h * w]).permute([0, 2, 1]);
  // LayerNorm
  out = norm.forward(out);
  return out;
}
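
To see what the reshape and permute step does to the data, the sketch below reproduces it on a flat list in plain Dart. It is not the library's implementation: `flattenPatches` is a hypothetical helper, and row-major storage is an assumption. The value for channel c at spatial position (i, j) ends up at patch index i * w + j, feature index c.

```dart
/// Rearranges a row-major (batch, channels, h, w) buffer into a
/// row-major (batch, h * w, channels) buffer.
///
/// Hypothetical helper for illustration; assumes row-major storage.
List<double> flattenPatches(
  List<double> data,
  int batch,
  int channels,
  int h,
  int w,
) {
  final numPatches = h * w;
  if (data.length != batch * channels * numPatches) {
    throw ArgumentError('data length does not match the given shape');
  }
  final out = List<double>.filled(data.length, 0.0);
  for (var b = 0; b < batch; b++) {
    for (var c = 0; c < channels; c++) {
      for (var p = 0; p < numPatches; p++) {
        // Source index in the (batch, channels, h * w) layout.
        final src = (b * channels + c) * numPatches + p;
        // Destination index in the (batch, h * w, channels) layout.
        final dst = (b * numPatches + p) * channels + c;
        out[dst] = data[src];
      }
    }
  }
  return out;
}
```

With one batch item, two channels, and a 1 × 3 patch grid, `flattenPatches([1.0, 2.0, 3.0, 10.0, 20.0, 30.0], 1, 2, 1, 3)` returns `[1.0, 10.0, 2.0, 20.0, 3.0, 30.0]`: each patch now holds its two channel values side by side.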