mlx_random_multivariate_normal function

int mlx_random_multivariate_normal(
  Pointer<mlx_array> res,
  mlx_array mean,
  mlx_array cov,
  Pointer<Int> shape,
  int shape_num,
  mlx_dtype_ dtype,
  mlx_array key,
  mlx_stream s,
)

Implementation

/// Dart-side wrapper around the native `_mlx_random_multivariate_normal`
/// binding (presumably sampling from a multivariate normal distribution,
/// judging by the name — confirm against the MLX C API docs).
///
/// Every argument is passed through unchanged, except [dtype]: the
/// `mlx_dtype_` wrapper is unwrapped to its `.value` before the call.
///
/// NOTE(review): [res] appears to be the out-parameter that receives the
/// result array, [shape]/[shape_num] a pointer + length pair describing the
/// sample shape, and the returned `int` a native status code — none of this
/// is visible from here; verify against the C header.
int mlx_random_multivariate_normal(
  ffi.Pointer<mlx_array> res,
  mlx_array mean,
  mlx_array cov,
  ffi.Pointer<ffi.Int> shape,
  int shape_num,
  mlx_dtype_ dtype,
  mlx_array key,
  mlx_stream s,
) {
  // The native symbol takes the raw dtype value, not the Dart wrapper type.
  final rawDtype = dtype.value;
  return _mlx_random_multivariate_normal(
      res, mean, cov, shape, shape_num, rawDtype, key, s);
}