/// Runs the engine's AFT forward op over [q], [k], [v] and [wb].
///
/// The native handles of all four tensors, plus [masked], are passed to
/// `engine.aftForward`. The handle it returns is wrapped in a new [Tensor]
/// that takes on the shape of [q].
///
/// NOTE(review): "AFT" is presumably the Attention Free Transformer. [wb]
/// looks like a learned position-bias tensor, and [masked] looks like it
/// selects a causal mask. Neither is visible here; confirm against
/// `engine.aftForward`.
static Tensor aft(Tensor q, Tensor k, Tensor v, Tensor wb, bool masked) {
  // Call the native kernel first, then read q.shape. This is the same order
  // the original argument list used.
  final outHandle = engine.aftForward(
    q._handle,
    k._handle,
    v._handle,
    wb._handle,
    masked,
  );
  // The output reuses q's shape. No shape check is done on this side.
  return Tensor._raw(outHandle, q.shape);
}