ggml_flash_attn_ext function

  1. @Native<Pointer<ggml_tensor> Function(Pointer<ggml_context>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Float, Float, Float)>()
Pointer<ggml_tensor> ggml_flash_attn_ext(
  1. Pointer<ggml_context> ctx,
  2. Pointer<ggml_tensor> q,
  3. Pointer<ggml_tensor> k,
  4. Pointer<ggml_tensor> v,
  5. Pointer<ggml_tensor> mask,
  6. double scale,
  7. double max_bias,
  8. double logit_softcap,
)

Tensor shapes:

  * q: n_embd_k, n_batch, n_head, ne3
  * k: n_embd_k, n_kv, n_head_kv, ne3
  * v: n_embd_v, n_kv, n_head_kv, ne3 (!! not transposed !!)
  * mask: n_kv, n_batch, ne32, ne33
  * res: n_embd_v, n_head, n_batch, ne3 (!! permuted !!)

broadcast (all of the following must hold):

  * n_head % n_head_kv == 0
  * n_head % ne32 == 0
  * ne3 % ne33 == 0

Implementation

/// Dart FFI binding for the native `ggml_flash_attn_ext` function.
///
/// This is an `external` declaration: it has no Dart body and the call is
/// forwarded to the native library through the `@ffi.Native` annotation.
///
/// Tensor layouts, as given by the accompanying API description:
///
/// * [q]:    n_embd_k, n_batch, n_head,    ne3
/// * [k]:    n_embd_k, n_kv,    n_head_kv, ne3
/// * [v]:    n_embd_v, n_kv,    n_head_kv, ne3  (not transposed)
/// * [mask]: n_kv,     n_batch, ne32,      ne33
/// * result: n_embd_v, n_head,  n_batch,   ne3  (permuted)
///
/// Broadcast constraints from the same description:
/// `n_head % n_head_kv == 0`, `n_head % ne32 == 0`, `ne3 % ne33 == 0`.
///
/// [scale], [max_bias] and [logit_softcap] are declared as `ffi.Float` on the
/// native side, so the Dart `double` arguments are passed as 32-bit floats and
/// may lose precision.
///
/// NOTE(review): the meaning of the three scalars is not shown here —
/// presumably the attention score scale, the ALiBi maximum bias and the logit
/// soft-cap value; confirm against the ggml header before relying on this.
///
/// Returns a pointer to the result tensor described above.
@ffi.Native<
  ffi.Pointer<ggml_tensor> Function(
    ffi.Pointer<ggml_context>,
    ffi.Pointer<ggml_tensor>,
    ffi.Pointer<ggml_tensor>,
    ffi.Pointer<ggml_tensor>,
    ffi.Pointer<ggml_tensor>,
    ffi.Float,
    ffi.Float,
    ffi.Float,
  )
>()
external ffi.Pointer<ggml_tensor> ggml_flash_attn_ext(
  ffi.Pointer<ggml_context> ctx,
  ffi.Pointer<ggml_tensor> q,
  ffi.Pointer<ggml_tensor> k,
  ffi.Pointer<ggml_tensor> v,
  ffi.Pointer<ggml_tensor> mask,
  double scale,
  double max_bias,
  double logit_softcap,
);