forward method

Tensor forward(
  Tensor x,
  int h,
  int w
)

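Forward pass: runs x through every block of this stage in order, then applies patch merging if the stage has one.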

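Input: (batch, h*w, dim)
Output: (batch, h'*w', dim'). When patchMerging is non-null, h' = h/2, w' = w/2 and dim' is the merging layer's output dimension; otherwise h' = h, w' = w and dim' = dim.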
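The shape contract can be written down as a small pure function. This is a sketch under assumptions the page does not state: the blocks keep the shape unchanged, and patch merging follows the standard Swin Transformer rule of halving h and w and doubling the channel dimension (dim' = 2 * dim), with h and w required to be even. `ShapeInfo` and `stageOutputShape` are hypothetical names used only for illustration.

```dart
/// Hypothetical shape descriptor for a (batch, h*w, dim) token tensor.
class ShapeInfo {
  final int batch;
  final int h;
  final int w;
  final int dim;
  const ShapeInfo(this.batch, this.h, this.w, this.dim);

  /// Number of tokens, i.e. the middle axis h*w.
  int get tokens => h * w;

  @override
  String toString() => '($batch, $tokens, $dim)';
}

/// Output shape of one stage. Assumes the blocks are shape-preserving and,
/// if patch merging is present, that it follows the standard Swin rule
/// (h/2, w/2, 2*dim). This sketch rejects odd h or w instead of padding.
ShapeInfo stageOutputShape(ShapeInfo input, {required bool hasPatchMerging}) {
  if (!hasPatchMerging) return input;
  if (input.h.isOdd || input.w.isOdd) {
    throw ArgumentError(
        'patch merging assumes even h and w, got ${input.h}x${input.w}');
  }
  return ShapeInfo(input.batch, input.h ~/ 2, input.w ~/ 2, input.dim * 2);
}

void main() {
  const input = ShapeInfo(2, 56, 56, 96);
  print(stageOutputShape(input, hasPatchMerging: true)); // (2, 784, 192)
  print(stageOutputShape(input, hasPatchMerging: false)); // (2, 3136, 96)
}
```

Some Swin implementations pad odd h or w before merging rather than rejecting them; the explicit error here is a simplifying choice for the sketch.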

Implementation

Tensor forward(Tensor x, int h, int w) {
  for (final block in blocks) {
    x = block.forward(x, h, w);
  }

  if (patchMerging != null) {
    x = patchMerging!.forward(x, h, w);
  }

  return x;
}
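Because forward returns only the tensor and not the new grid size, a caller that chains several stages has to update h and w itself after each stage that merges patches. The sketch below mirrors the implementation above with hypothetical shape-only stand-ins (`FakeTensor`, `Block`, `PatchMerging`, `Stage`); the real Tensor type and layer classes are not shown on this page, and the merging rule (halve h and w, double dim) is the standard Swin assumption, not something this page confirms.

```dart
/// Hypothetical stand-in for Tensor that only tracks its shape.
class FakeTensor {
  final int batch, tokens, dim;
  const FakeTensor(this.batch, this.tokens, this.dim);
}

/// Hypothetical block: shape-preserving, as Swin blocks are assumed to be.
class Block {
  FakeTensor forward(FakeTensor x, int h, int w) {
    assert(x.tokens == h * w, 'token count must equal h*w');
    return x;
  }
}

/// Hypothetical patch merging, assuming the standard Swin rule:
/// (batch, h*w, dim) -> (batch, (h/2)*(w/2), 2*dim).
class PatchMerging {
  FakeTensor forward(FakeTensor x, int h, int w) =>
      FakeTensor(x.batch, (h ~/ 2) * (w ~/ 2), x.dim * 2);
}

/// Mirrors the forward() implementation above.
class Stage {
  final List<Block> blocks;
  final PatchMerging? patchMerging;
  Stage(this.blocks, this.patchMerging);

  FakeTensor forward(FakeTensor x, int h, int w) {
    for (final block in blocks) {
      x = block.forward(x, h, w);
    }
    if (patchMerging != null) {
      x = patchMerging!.forward(x, h, w);
    }
    return x;
  }
}

void main() {
  // Illustrative configuration: two merging stages, then one without merging.
  final stages = [
    Stage([Block(), Block()], PatchMerging()),
    Stage([Block(), Block()], PatchMerging()),
    Stage([Block(), Block()], null),
  ];
  var h = 8, w = 8;
  var x = FakeTensor(1, h * w, 32);
  for (final stage in stages) {
    x = stage.forward(x, h, w);
    // forward() does not report the new grid size, so the caller halves h and w.
    if (stage.patchMerging != null) {
      h ~/= 2;
      w ~/= 2;
    }
  }
  print('(${x.batch}, ${x.tokens}, ${x.dim}) with h=$h, w=$w'); // (1, 4, 128) with h=2, w=2
}
```

An alternative design would have forward return the tensor together with the updated h and w, which removes the bookkeeping from the caller at the cost of a wider return type.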