forward method
Forward pass: embeds image patches and layer-normalizes the result.
Input: `(batch, channels, height, width)`. Output: `(batch, numPatches, embedDim)`.
Implementation
/// Turns an image batch into a normalized sequence of patch embeddings.
///
/// [x] is expected to be `(batch, channels, height, width)`. It is passed
/// through [proj] (a Conv2d producing `(batch, embedDim, h, w)`). The
/// spatial grid is then flattened and moved ahead of the channel axis, and
/// the tokens go through [norm] (LayerNorm).
///
/// Returns a tensor of shape `(batch, h * w, embedDim)`.
Tensor forward(Tensor x) {
  final batchSize = x.shape[0];

  // Patchify: (batch, 3, H, W) → (batch, embedDim, H/patch, W/patch).
  final projected = proj.forward(x);
  final gridH = projected.shape[2];
  final gridW = projected.shape[3];
  final numPatches = gridH * gridW;

  // Flatten the patch grid, then swap axes so channels come last:
  // (batch, embedDim, h, w) → (batch, embedDim, h*w) → (batch, h*w, embedDim).
  final tokens = projected
      .reshape([batchSize, embedDim, numPatches])
      .permute([0, 2, 1]);

  return norm.forward(tokens);
}