forward method

Tensor forward(
  1. Tensor input
)

Forward pass: computes `input @ weight^T + bias` (the bias term is skipped when `bias` is null).

Implementation

/// Applies the linear transformation `input @ weight^T + bias`.
///
/// [input] is expected to have shape `(..., inFeatures)`; the result has
/// shape `(..., outFeatures)`. When `bias` is `null`, only the matrix
/// product is returned.
Tensor forward(Tensor input) {
  // weight is stored as (outFeatures, inFeatures); swap its two axes so the
  // product contracts over inFeatures.
  final weightT = weight.transpose(0, 1); // (inFeatures, outFeatures)
  final product = input.matmul(weightT); // (..., outFeatures)

  // Bind the field to a local so the null check promotes it; this avoids a
  // `!` assertion on a second read of the field.
  final b = bias;
  if (b == null) return product;

  // View the bias as (1, ..., 1, outFeatures) — one singleton axis per
  // leading dimension of the product — then expand it to the product's shape.
  return product +
      b.reshape([
        for (var i = 0; i < product.ndim - 1; i++) 1,
        outFeatures,
      ]).expand(product.shape);
}