ggml_flash_attn_ext function
- @Native<Pointer<ggml_tensor> Function(Pointer<ggml_context>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Pointer<ggml_tensor>, Float, Float, Float)>()
- Pointer<ggml_context> ctx,
- Pointer<ggml_tensor> q,
- Pointer<ggml_tensor> k,
- Pointer<ggml_tensor> v,
- Pointer<ggml_tensor> mask,
- double scale,
- double max_bias,
- double logit_softcap,
q: n_embd_k, n_batch, n_head, ne3
k: n_embd_k, n_kv, n_head_kv, ne3
v: n_embd_v, n_kv, n_head_kv, ne3 !! not transposed !!
mask: n_kv, n_batch, ne32, ne33
res: n_embd_v, n_head, n_batch, ne3 !! permuted !!
broadcast: n_head % n_head_kv == 0, n_head % ne32 == 0, ne3 % ne33 == 0
Implementation
/// Dart FFI binding for the native `ggml_flash_attn_ext` function.
///
/// This is an `external` declaration only: the implementation lives in the
/// native library and is bound through the `@ffi.Native` annotation below.
///
/// Tensor layouts, as listed in the accompanying API notes:
///
/// * [q]:    `n_embd_k, n_batch, n_head,    ne3`
/// * [k]:    `n_embd_k, n_kv,    n_head_kv, ne3`
/// * [v]:    `n_embd_v, n_kv,    n_head_kv, ne3` (NOT transposed)
/// * [mask]: `n_kv,     n_batch, ne32,      ne33`
/// * result: `n_embd_v, n_head,  n_batch,   ne3` (permuted)
///
/// Broadcast requirements from the same notes:
/// `n_head % n_head_kv == 0`, `n_head % ne32 == 0`, `ne3 % ne33 == 0`.
///
/// [ctx] is the `ggml_context` pointer handed to the native call.
///
/// [scale], [max_bias] and [logit_softcap] are Dart `double`s on this side but
/// are declared as `ffi.Float` in the native signature.
/// NOTE(review): the meaning of these three scalars is not described here;
/// presumably the attention scaling factor, the ALiBi maximum bias and the
/// logit soft-cap. Confirm against the ggml header before relying on this.
///
/// Returns a pointer to the resulting `ggml_tensor`.
/// NOTE(review): ownership/lifetime of the returned tensor and whether [mask]
/// may be `nullptr` are not shown here; verify against the native API.
@ffi.Native<
ffi.Pointer<ggml_tensor> Function(
ffi.Pointer<ggml_context>,
ffi.Pointer<ggml_tensor>,
ffi.Pointer<ggml_tensor>,
ffi.Pointer<ggml_tensor>,
ffi.Pointer<ggml_tensor>,
ffi.Float,
ffi.Float,
ffi.Float,
)
>()
external ffi.Pointer<ggml_tensor> ggml_flash_attn_ext(
ffi.Pointer<ggml_context> ctx,
ffi.Pointer<ggml_tensor> q,
ffi.Pointer<ggml_tensor> k,
ffi.Pointer<ggml_tensor> v,
ffi.Pointer<ggml_tensor> mask,
double scale,
double max_bias,
double logit_softcap,
);