vectorSwish function

Tensor<Vector> vectorSwish(Tensor<Vector> v)

Applies the Swish activation function, $\mathrm{swish}(x) = x \cdot \sigma(x)$ with $\sigma(x) = \frac{1}{1 + e^{-x}}$, element-wise to a vector tensor.

Implementation

/// Applies Swish, `swish(x) = x * sigmoid(x)`, to every element of [v].
///
/// Returns a new tensor whose `creator` is a [Node] with [v] as its input.
/// That node's backward callback adds `out.grad[i] * swish'(v.value[i])`
/// into `v.grad[i]`. The sigmoid of each input is cached during the forward
/// pass so the backward callback can reuse it without recomputing `exp`.
Tensor<Vector> vectorSwish(Tensor<Vector> v) {
  final int length = v.value.length;
  final Vector activated = [];
  final Vector cachedSigmoid = []; // reused by the backward callback

  for (final x in v.value) {
    final double s = 1.0 / (1.0 + exp(-x));
    cachedSigmoid.add(s);
    activated.add(x * s);
  }

  final Tensor<Vector> out = Tensor<Vector>(activated);

  // Local gradient: d/dx swish(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
  // v.value[i] is read when backward runs, not captured at forward time.
  void backward() {
    for (int i = 0; i < length; i++) {
      final double s = cachedSigmoid[i];
      final double localGrad = s * (1 + v.value[i] * (1 - s));
      v.grad[i] += out.grad[i] * localGrad;
    }
  }

  // Cost estimate: roughly 2 ops per element.
  out.creator = Node([v], backward, opName: 'swish', cost: length * 2);
  return out;
}