forward method
Forward pass.
Input: `(batch, h*w, dim)`, where `h` and `w` are the spatial dimensions (both must be even). Output: `(batch, (h/2)*(w/2), 2*dim)`.
Implementation
/// Merges every 2x2 neighbourhood of tokens into a single token.
///
/// [x] is reshaped to `(batch, h, w, dim)`, where `batch` is `x.shape[0]`.
/// The documented input layout is `(batch, h*w, dim)`. [h] and [w] are the
/// spatial dimensions and must both be even.
///
/// For each output position `(i, j)`, the four input tokens at rows `2i` and
/// `2i+1` and columns `2j` and `2j+1` are written side by side into a
/// `4*dim`-wide token. The order is top-left, top-right, bottom-left,
/// bottom-right. This produces a `(batch, (h/2)*(w/2), 4*dim)` tensor, which
/// is passed through `norm.forward` and then `reduction.forward`. The
/// projection's output is returned; per the original docs it is
/// `(batch, (h/2)*(w/2), 2*dim)`, though that width is set by `reduction`.
///
/// Throws [ArgumentError] if [h] or [w] is odd. Previously, an odd last
/// row/column was silently dropped by the integer halving.
Tensor forward(Tensor x, int h, int w) {
  if (h.isOdd || w.isOdd) {
    throw ArgumentError(
        'h and w must both be even to merge 2x2 patches; got h=$h, w=$w');
  }
  final batch = x.shape[0];
  // View the token sequence as a spatial grid: (batch, h, w, dim).
  final grid = x.reshape([batch, h, w, dim]);
  final newH = h ~/ 2;
  final newW = w ~/ 2;
  final mergedDim = 4 * dim;
  final outData = Float32List(batch * newH * newW * mergedDim);
  // Each merged token's channels are laid out as [TL | TR | BL | BR]; each
  // slot is `dim` wide. `_copyPatch` is assumed to copy the `dim` channels
  // of grid token (b, row, col) into outData at the given offset.
  // NOTE(review): the reference Swin Transformer PatchMerging concatenates
  // in the order TL, BL, TR, BR. If `reduction` weights are ported from that
  // implementation, confirm this ordering matches.
  for (var b = 0; b < batch; b++) {
    for (var i = 0; i < newH; i++) {
      final row = 2 * i;
      for (var j = 0; j < newW; j++) {
        final col = 2 * j;
        final dstBase = ((b * newH + i) * newW + j) * mergedDim;
        _copyPatch(grid, b, row, col, w, dim, outData, dstBase);
        _copyPatch(grid, b, row, col + 1, w, dim, outData, dstBase + dim);
        _copyPatch(grid, b, row + 1, col, w, dim, outData, dstBase + 2 * dim);
        _copyPatch(
            grid, b, row + 1, col + 1, w, dim, outData, dstBase + 3 * dim);
      }
    }
  }
  final merged = Tensor(outData, [batch, newH * newW, mergedDim]);
  // Layer norm, then linear projection.
  return reduction.forward(norm.forward(merged));
}