forward method
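Forward pass through this stage: runs every block at resolution (h, w), then applies patch merging if one is configured.
Input: x of shape (batch, h*w, dim). Output: (batch, h'*w', dim'). Without patch merging, h' = h, w' = w and dim' = dim. With patch merging, h' = h/2, w' = w/2, and dim' is the merging layer's output width (2*dim in the standard Swin Transformer design). The returned tensor does not carry the new resolution, so the caller must track h' and w' itself.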
Implementation
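```dart
Tensor forward(Tensor x, int h, int w) {
  // Each block preserves the shape (batch, h*w, dim).
  for (final block in blocks) {
    x = block.forward(x, h, w);
  }
  // Optional downsampling at the end of the stage: halves h and w.
  // `!` is needed because Dart does not promote nullable fields.
  if (patchMerging != null) {
    x = patchMerging!.forward(x, h, w);
  }
  return x;
}
```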
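To make the shape bookkeeping concrete, the following is a minimal, self-contained sketch that mirrors the control flow above but tracks shapes only. `ShapeTensor`, `FakeBlock`, `FakePatchMerging` and `FakeStage` are hypothetical stand-ins, not the real `Tensor` or layer classes, and the merging step assumes Swin-style 2x2 patch merging (token count divided by 4, channels doubled).

```dart
/// Hypothetical shape-only stand-in for Tensor: (batch, tokens, dim).
class ShapeTensor {
  final int batch, tokens, dim;
  const ShapeTensor(this.batch, this.tokens, this.dim);

  @override
  String toString() => '($batch, $tokens, $dim)';
}

/// Hypothetical block: attention + MLP keep the shape unchanged.
class FakeBlock {
  ShapeTensor forward(ShapeTensor x, int h, int w) {
    assert(x.tokens == h * w, 'token count must equal h*w');
    return x;
  }
}

/// Hypothetical Swin-style patch merging: each 2x2 neighbourhood becomes
/// one token, so h and w halve and channels go 4*dim -> 2*dim.
class FakePatchMerging {
  ShapeTensor forward(ShapeTensor x, int h, int w) {
    if (h.isOdd || w.isOdd) {
      throw ArgumentError('h and w must be even for 2x2 merging');
    }
    return ShapeTensor(x.batch, (h ~/ 2) * (w ~/ 2), 2 * x.dim);
  }
}

/// Hypothetical stage with the same forward logic as the method above.
class FakeStage {
  final List<FakeBlock> blocks;
  final FakePatchMerging? patchMerging;
  FakeStage(this.blocks, this.patchMerging);

  ShapeTensor forward(ShapeTensor x, int h, int w) {
    for (final block in blocks) {
      x = block.forward(x, h, w);
    }
    // Copy the field to a local so Dart can promote it to non-null.
    final merging = patchMerging;
    if (merging != null) {
      x = merging.forward(x, h, w);
    }
    return x;
  }
}
```

With two blocks and merging enabled, an input of shape (1, 56*56, 96) at 56x56 comes out as (1, 784, 192), so the next stage has to be called with h = w = 28. A stage without merging (as in the last Swin stage) returns its input shape unchanged.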