mlx_random_multivariate_normal function
int mlx_random_multivariate_normal(
  Pointer<mlx_array> res,
  mlx_array mean,
  mlx_array cov,
  Pointer<Int> shape,
  int shape_num,
  mlx_dtype_ dtype,
  mlx_array key,
  mlx_stream s,
)
Implementation
/// Forwards to the private `_mlx_random_multivariate_normal` binding.
///
/// Every argument is passed through unchanged and in the same order, except
/// [dtype], which is forwarded as `dtype.value` rather than as the
/// `mlx_dtype_` object itself.
///
/// Returns whatever `int` the underlying binding returns.
/// NOTE(review): presumably a native status code, and [res] presumably
/// receives the output array — confirm against the MLX C API.
int mlx_random_multivariate_normal(
  ffi.Pointer<mlx_array> res,
  mlx_array mean,
  mlx_array cov,
  ffi.Pointer<ffi.Int> shape,
  int shape_num,
  mlx_dtype_ dtype,
  mlx_array key,
  mlx_stream s,
) {
  // The private binding takes the dtype's `value`, not the wrapper object.
  final dtypeValue = dtype.value;

  return _mlx_random_multivariate_normal(
    res,
    mean,
    cov,
    shape,
    shape_num,
    dtypeValue,
    key,
    s,
  );
}