vectorSwish function
Applies the Swish activation, swish(x) = x · sigmoid(x), element-wise to a vector and records the backward pass for automatic differentiation.
Implementation
/// Applies the Swish activation `swish(x) = x * sigmoid(x)` to every
/// element of [v] and returns the result as a new tensor.
///
/// The sigmoid of each input is computed once and cached so the backward
/// closure can reuse it. On backward, `out.grad[k] * swish'(x_k)` is
/// accumulated into `v.grad[k]`, where
/// `swish'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))`.
/// The backward closure reads `v.value` at the time it runs.
Tensor<Vector> vectorSwish(Tensor<Vector> v) {
  final int count = v.value.length;

  // Cached sigmoid(x_k) values, shared between the forward and backward passes.
  final Vector sig =
      List.generate(count, (k) => 1.0 / (1.0 + exp(-v.value[k])));
  final Vector activated = List.generate(count, (k) => v.value[k] * sig[k]);

  final Tensor<Vector> out = Tensor<Vector>(activated);
  out.creator = Node([v], () {
    for (var k = 0; k < count; k++) {
      final double s = sig[k];
      // Chain rule with swish'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
      final double slope = s * (1 + v.value[k] * (1 - s));
      v.grad[k] += out.grad[k] * slope;
    }
  }, opName: 'swish', cost: count * 2); // cost heuristic: ~2 ops per element
  return out;
}