adamUpdateGPU function

void adamUpdateGPU(
  GPUTensor<dynamic> data,
  GPUTensor<dynamic> m,
  GPUTensor<dynamic> v,
  double lr,
  double beta1,
  double beta2,
  double eps,
  int step,
  double weightDecay,
  CommandBuffer tape,
)

Implementation

/// Records an Adam optimizer update for [data] onto [tape].
///
/// No arithmetic happens here: the function only appends an
/// `OP_ADAM_UPDATE` opcode followed by its operands. Presumably a tape
/// consumer later decodes and executes it on the GPU (TODO confirm).
///
/// Operands are written in this fixed order; any decoder of the record
/// presumably has to read them back in the same order:
///  1. `OP_ADAM_UPDATE` via `putInt`
///  2. `data.id` (the tensor being updated) via `putString`
///  3. `'${data.id}_grad'` via `putString`. This assumes the gradient of
///     [data] is registered under this `<id>_grad` name; verify.
///  4. `m.id`, then `v.id` via `putString`. These are presumably the first-
///     and second-moment buffers; confirm.
///  5. [lr], [beta1], [beta2], [eps] via `putFloat`
///  6. [step] via `putInt`
///  7. [weightDecay] via `putFloat`
///
/// NOTE(review): whether `putFloat` narrows these doubles to 32-bit is
/// decided by `CommandBuffer`. Check this if precision of small [eps]
/// values matters.
void adamUpdateGPU(
    GPUTensor<dynamic> data,
    GPUTensor<dynamic> m,
    GPUTensor<dynamic> v,
    double lr,
    double beta1,
    double beta2,
    double eps,
    int step,
    double weightDecay,
    CommandBuffer tape) {
  // Record header: the opcode, then the ids of every tensor the update uses.
  tape
    ..putInt(OP_ADAM_UPDATE)
    ..putString(data.id)
    ..putString('${data.id}_grad')
    ..putString(m.id)
    ..putString(v.id);

  // Float hyperparameters, emitted back-to-back in decoder order.
  for (final hyper in <double>[lr, beta1, beta2, eps]) {
    tape.putFloat(hyper);
  }

  // Trailing operands: the integer step counter, then weight decay.
  tape
    ..putInt(step)
    ..putFloat(weightDecay);
}