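`forward` method: multiplies the input by `w`, adds the bias `b`, optionally applies GELU, and records each intermediate tensor in `tracker`.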

Tensor forward(
  Tensor x,
  List<Tensor> tracker,
)

Implementation

/// Runs the layer on [x]: `x.matmul(w) + b`, optionally followed by GELU.
///
/// Every intermediate tensor is appended to [tracker] in creation order:
/// the matmul result, then the biased result, then (only when `useGelu` is
/// set) the activated result.
///
/// NOTE(review): `w`, `b` and `useGelu` come from the enclosing class. The
/// caller presumably uses [tracker] for backward passes or native-memory
/// cleanup. Confirm against the caller.
///
/// Returns the GELU output when `useGelu` is true; otherwise returns the
/// biased linear output.
Tensor forward(Tensor x, List<Tensor> tracker) {
  // Matrix product, not element-wise multiplication.
  final projected = x.matmul(w);

  // Bias add relies on `+` broadcasting `b` over `projected`. The original
  // note attributes this to the native max(shape) logic.
  final biased = projected + b;

  tracker
    ..add(projected)
    ..add(biased);

  // Without the activation, the affine result is the layer output.
  if (!useGelu) return biased;

  final activated = biased.gelu();
  tracker.add(activated);
  return activated;
}