mishVector function

Tensor<Vector> mishVector(
  Tensor<Vector> v
)

Implementation

/// Applies the Mish activation, `mish(x) = x * tanh(softplus(x))` with
/// `softplus(x) = ln(1 + e^x)`, element-wise to [v] and registers the
/// backward pass on the returned tensor.
///
/// `tanh(softplus(x))` is evaluated with the closed form
///
///     tanh(softplus(x)) = q / (q + 2),   where q = e^x * (e^x + 2)
///
/// instead of `exp(2 * log(1 + exp(x)))`. The latter rounds `1 + e^x` to
/// exactly `1` for x below about -37, which silently zeroes the output for
/// negative inputs. The closed form needs a single `exp` per element and
/// stays accurate there. For x > 20 the exact result rounds to 1.0 in double
/// precision, so that branch returns 1.0 directly, which also keeps `q` from
/// overflowing.
///
/// The per-element `tanh(softplus(x))` values computed in the forward pass are
/// cached and reused by the backward closure.
///
/// Backward: with `t = tanh(softplus(x))`,
/// `d/dx mish(x) = t + x * sigmoid(x) * (1 - t^2)`. `out.grad[i]` times this
/// derivative is accumulated into `v.grad[i]`.
///
/// Returns a new tensor with one output per element of `v.data`.
Tensor<Vector> mishVector(Tensor<Vector> v) {
  // Above this input, tanh(softplus(x)) == 1.0 exactly in double precision.
  const double saturation = 20.0;

  final int n = v.data.length;
  final Vector outValue = [];
  // tanh(softplus(x_i)) per element, reused by the backward closure below.
  final List<double> tanhSp = List<double>.filled(n, 0.0);

  for (int i = 0; i < n; i++) {
    final double x = v.data[i];
    double t;
    if (x > saturation) {
      t = 1.0;
    } else {
      final double e = exp(x);
      final double q = e * (e + 2.0); // (1 + e^x)^2 - 1
      t = q / (q + 2.0);
    }
    tanhSp[i] = t;
    outValue.add(x * t);
  }

  final Tensor<Vector> out = Tensor<Vector>(outValue);

  out.creator = Node(
    [v],
    () {
      for (int i = 0; i < n; i++) {
        final double x = v.data[i];
        final double t = tanhSp[i];
        final double s = 1.0 / (1.0 + exp(-x)); // sigmoid(x)
        v.grad[i] += out.grad[i] * (t + x * s * (1.0 - t * t));
      }
    },
    opName: 'mish_vector',
    cost: n * 3,
  );

  return out;
}