forward method

Tensor forward(
  Tensor x
)

Forward pass.

Input: (batch, 3, height, width), the preprocessed document image.
Output: (batch, numTokens, outputDim), a sequence of embeddings.
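How numTokens relates to height and width is not spelled out above. The sketch below predicts the token grid that forward would produce, mirroring the loop in the implementation: it starts from the patch grid and halves it (with integer division, like ~/) after every layer that performs patch merging. The helper swinTokenGrid, the square patchSize parameter, and the assumption that patchH = height ~/ patchSize and patchW = width ~/ patchSize are all hypothetical; padding for image sizes that are not divisible by the patch size is ignored.

```dart
/// Hypothetical helper: predicts the (h, w) token grid and token count that
/// forward() returns. Assumes a square patch size and that the patch grid
/// is (height ~/ patchSize) x (width ~/ patchSize); padding is ignored.
({int h, int w, int numTokens}) swinTokenGrid({
  required int height,
  required int width,
  required int patchSize,
  required List<bool> layerHasMerging, // one entry per Swin layer
}) {
  // Grid right after patch embedding (assumed equivalent of patchH/patchW).
  var h = height ~/ patchSize;
  var w = width ~/ patchSize;

  // Each layer with patch merging halves both grid dimensions.
  for (final merges in layerHasMerging) {
    if (merges) {
      h = h ~/ 2;
      w = w ~/ 2;
    }
  }
  return (h: h, w: w, numTokens: h * w);
}
```

For example, a 2560 × 1920 input with an assumed patch size of 4 and four layers where only the last one skips merging gives a 640 × 480 patch grid, which shrinks to 80 × 60, i.e. 4800 tokens.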

Implementation

Tensor forward(Tensor x) {
  // Patch embedding: (batch, 3, H, W) → (batch, numPatches, embedDim)
  x = patchEmbed.forward(x);
  x = posDropout.forward(x);

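  // Token-grid size (rows, columns) right after patch embedding. It is passed
  // to every layer so the flat token sequence can be viewed as an h × w grid.
  int h = patchH;
  int w = patchW;

  // Apply Swin layers
  for (int i = 0; i < layers.length; i++) {
    x = layers[i].forward(x, h, w);
    // Patch merging inside this layer halved the grid; track the new size.
    if (layers[i].patchMerging != null) {
      h = h ~/ 2;
      w = w ~/ 2;
    }
  }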

  return x;
}
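The loop halves h and w whenever a layer has patchMerging, because patch merging combines each 2 × 2 neighbourhood of tokens into one. The sketch below shows only that gather step on a plain row-major token list, assuming tokens are stored as (h * w) rows of c channels; the concatenation order follows the original Swin Transformer code as we understand it. mergePatches is a hypothetical name, and the normalization and 4c → 2c linear projection that a real merging layer would apply afterwards are deliberately omitted.

```dart
/// Hypothetical sketch of the 2x2 gather step of Swin-style patch merging.
/// [tokens] is row-major: token (row, col) lives at index row * w + col,
/// and each token has c channels. Returns (h ~/ 2) * (w ~/ 2) tokens with
/// 4c channels each. The follow-up norm + linear (4c -> 2c) is omitted.
List<List<double>> mergePatches(List<List<double>> tokens, int h, int w) {
  if (tokens.length != h * w) {
    throw ArgumentError('expected ${h * w} tokens, got ${tokens.length}');
  }
  final outH = h ~/ 2;
  final outW = w ~/ 2;
  final merged = <List<double>>[];
  for (var r = 0; r < outH; r++) {
    for (var col = 0; col < outW; col++) {
      final topLeft = 2 * r * w + 2 * col; // index of token (2r, 2col)
      merged.add([
        ...tokens[topLeft], // (2r,     2col)
        ...tokens[topLeft + w], // (2r + 1, 2col)
        ...tokens[topLeft + 1], // (2r,     2col + 1)
        ...tokens[topLeft + w + 1], // (2r + 1, 2col + 1)
      ]);
    }
  }
  return merged;
}
```

With h = 2, w = 4 and single-channel tokens numbered 0 to 7 row by row, the result has two tokens, [0, 4, 1, 5] and [2, 6, 3, 7], so the grid goes from 2 × 4 to 1 × 2 while the channel count quadruples.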