dart_mlx_ffi_bindings_generated library

Classes

mlx_array_
An N-dimensional array object.
mlx_closure_
Closures: MLX closure objects (Doxygen group `mlx_closure`).
mlx_closure_custom_
mlx_closure_custom_jvp_
mlx_closure_custom_vmap_
mlx_closure_kwargs_
mlx_closure_value_and_grad_
mlx_complex64_t
C float _Complex is laid out as two adjacent float values.
mlx_device_
An MLX device object.
mlx_device_info_
An MLX device info object. Contains key-value pairs with device properties. Keys vary by backend, but common keys include:
mlx_distributed_group_
An MLX distributed group object.
mlx_fast_cuda_kernel_
mlx_fast_cuda_kernel_config_
Fast custom operations (Doxygen group `fast`).
mlx_fast_metal_kernel_
mlx_fast_metal_kernel_config_
mlx_function_exporter_
mlx_imported_function_
mlx_io_reader_
An MLX IO reader object.
mlx_io_vtable_
Virtual table for custom IO reader and writer objects.
mlx_io_writer_
An MLX IO writer object.
mlx_map_string_to_array_
A string-to-array map.
mlx_map_string_to_array_iterator_
An iterator over a string-to-array map.
mlx_map_string_to_string_
A string-to-string map.
mlx_map_string_to_string_iterator_
An iterator over a string-to-string map.
mlx_optional_dtype_
An optional dtype.
mlx_optional_float_
An optional float.
mlx_optional_int_
An optional int.
mlx_stream_
An MLX stream object.
mlx_string_
An MLX string object.
mlx_vector_array_
A vector of arrays.
mlx_vector_int_
A vector of ints.
mlx_vector_string_
A vector of strings.
mlx_vector_vector_array_
A vector of vector_array objects.

Enums

mlx_compile_mode_
Compilation operations (Doxygen group `compile`).
mlx_device_type_
Device type.
mlx_dtype_
Array element type.

Functions

mlx_abs(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
Core array operations (Doxygen group `ops`).
mlx_add(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_addmm(Pointer<mlx_array> res, mlx_array c, mlx_array a, mlx_array b, double alpha, double beta, mlx_stream s) int
mlx_all(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_all_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_all_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_allclose(Pointer<mlx_array> res, mlx_array a, mlx_array b, double rtol, double atol, bool equal_nan, mlx_stream s) int
mlx_any(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_any_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_any_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_arange(Pointer<mlx_array> res, double start, double stop, double step, mlx_dtype_ dtype, mlx_stream s) int
mlx_arccos(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_arccosh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_arcsin(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_arcsinh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_arctan(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_arctan2(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_arctanh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_argmax(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_argmax_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_argmin(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_argmin_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_argpartition(Pointer<mlx_array> res, mlx_array a, int kth, mlx_stream s) int
mlx_argpartition_axis(Pointer<mlx_array> res, mlx_array a, int kth, int axis, mlx_stream s) int
mlx_argsort(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_argsort_axis(Pointer<mlx_array> res, mlx_array a, int axis, mlx_stream s) int
mlx_array_data_bfloat16(mlx_array arr) Pointer<Uint16>
mlx_array_data_bool(mlx_array arr) Pointer<Bool>
Returns a pointer to the array data, cast to bool*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_complex64(mlx_array arr) Pointer<mlx_complex64_t>
mlx_array_data_float16(mlx_array arr) Pointer<Uint16>
mlx_array_data_float32(mlx_array arr) Pointer<Float>
Returns a pointer to the array data, cast to float32*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_float64(mlx_array arr) Pointer<Double>
Returns a pointer to the array data, cast to float64*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_int16(mlx_array arr) Pointer<Int16>
Returns a pointer to the array data, cast to int16_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_int32(mlx_array arr) Pointer<Int32>
Returns a pointer to the array data, cast to int32_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_int64(mlx_array arr) Pointer<Int64>
Returns a pointer to the array data, cast to int64_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_int8(mlx_array arr) Pointer<Int8>
Returns a pointer to the array data, cast to int8_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_uint16(mlx_array arr) Pointer<Uint16>
Returns a pointer to the array data, cast to uint16_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_uint32(mlx_array arr) Pointer<Uint32>
Returns a pointer to the array data, cast to uint32_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_uint64(mlx_array arr) Pointer<Uint64>
Returns a pointer to the array data, cast to uint64_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_data_uint8(mlx_array arr) Pointer<Uint8>
Returns a pointer to the array data, cast to uint8_t*. Array must be evaluated, otherwise returns NULL.
mlx_array_dim(mlx_array arr, int dim) int
The shape of the array in a particular dimension.
mlx_array_dtype(mlx_array arr) mlx_dtype_
mlx_array_equal(Pointer<mlx_array> res, mlx_array a, mlx_array b, bool equal_nan, mlx_stream s) int
mlx_array_eval(mlx_array arr) int
Evaluate the array.
mlx_array_free(mlx_array arr) int
Free an array.
mlx_array_item_bfloat16(Pointer<Uint16> res, mlx_array arr) int
mlx_array_item_bool(Pointer<Bool> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_complex64(Pointer<mlx_complex64_t> res, mlx_array arr) int
mlx_array_item_float16(Pointer<Uint16> res, mlx_array arr) int
mlx_array_item_float32(Pointer<Float> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_float64(Pointer<Double> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_int16(Pointer<Int16> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_int32(Pointer<Int32> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_int64(Pointer<Int64> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_int8(Pointer<Int8> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_uint16(Pointer<Uint16> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_uint32(Pointer<Uint32> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_uint64(Pointer<Uint64> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_item_uint8(Pointer<Uint8> res, mlx_array arr) int
Access the value of a scalar array.
mlx_array_itemsize(mlx_array arr) int
The size of the array's datatype in bytes.
mlx_array_nbytes(mlx_array arr) int
The number of bytes in the array.
mlx_array_ndim(mlx_array arr) int
The number of dimensions of the array.
mlx_array_new() mlx_array
New empty array.
mlx_array_new_bool(bool val) mlx_array
New array from a bool scalar.
mlx_array_new_complex(double real_val, double imag_val) mlx_array
New array from a complex scalar.
mlx_array_new_data(Pointer<Void> data, Pointer<Int> shape, int dim, mlx_dtype_ dtype) mlx_array
mlx_array_new_data_managed(Pointer<Void> data, Pointer<Int> shape, int dim, mlx_dtype_ dtype, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_array
mlx_array_new_data_managed_payload(Pointer<Void> data, Pointer<Int> shape, int dim, mlx_dtype_ dtype, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_array
mlx_array_new_double(double val) mlx_array
New array from a double scalar. Same as float64.
mlx_array_new_float(double val) mlx_array
New array from a float scalar. Same as float32.
mlx_array_new_float32(double val) mlx_array
New array from a float32 scalar.
mlx_array_new_float64(double val) mlx_array
New array from a float64 scalar.
mlx_array_new_int(int val) mlx_array
New array from an int scalar.
mlx_array_set(Pointer<mlx_array> arr, mlx_array src) int
Set array to provided src array.
mlx_array_set_bool(Pointer<mlx_array> arr, bool val) int
Set array to a bool scalar.
mlx_array_set_complex(Pointer<mlx_array> arr, double real_val, double imag_val) int
Set array to a complex scalar.
mlx_array_set_data(Pointer<mlx_array> arr, Pointer<Void> data, Pointer<Int> shape, int dim, mlx_dtype_ dtype) int
mlx_array_set_double(Pointer<mlx_array> arr, double val) int
Set array to a double scalar.
mlx_array_set_float(Pointer<mlx_array> arr, double val) int
Set array to a float scalar.
mlx_array_set_float32(Pointer<mlx_array> arr, double val) int
Set array to a float32 scalar.
mlx_array_set_float64(Pointer<mlx_array> arr, double val) int
Set array to a float64 scalar.
mlx_array_set_int(Pointer<mlx_array> arr, int val) int
Set array to an int scalar.
mlx_array_shape(mlx_array arr) Pointer<Int>
The shape of the array. Returns: a pointer to the sizes of each dimension.
mlx_array_size(mlx_array arr) int
Number of elements in the array.
mlx_array_strides(mlx_array arr) Pointer<Size>
The strides of the array. Returns: a pointer to the stride of each dimension.
mlx_array_tostring(Pointer<mlx_string> str, mlx_array arr) int
Get array description.
mlx_as_strided(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shape, int shape_num, Pointer<Int64> strides, int strides_num, int offset, mlx_stream s) int
mlx_astype(Pointer<mlx_array> res, mlx_array a, mlx_dtype_ dtype, mlx_stream s) int
mlx_async_eval(mlx_vector_array outputs) int
Transform operations (Doxygen group `transforms`).
mlx_atleast_1d(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_atleast_2d(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_atleast_3d(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_bitwise_and(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_bitwise_invert(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_bitwise_or(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_bitwise_xor(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_blackman(Pointer<mlx_array> res, int M, mlx_stream s) int
mlx_block_masked_mm(Pointer<mlx_array> res, mlx_array a, mlx_array b, int block_size, mlx_array mask_out, mlx_array mask_lhs, mlx_array mask_rhs, mlx_stream s) int
mlx_broadcast_arrays(Pointer<mlx_vector_array> res, mlx_vector_array inputs, mlx_stream s) int
mlx_broadcast_to(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shape, int shape_num, mlx_stream s) int
mlx_ceil(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_checkpoint(Pointer<mlx_closure> res, mlx_closure fun) int
mlx_clear_cache() int
Memory operations (Doxygen group `memory`).
mlx_clip(Pointer<mlx_array> res, mlx_array a, mlx_array a_min, mlx_array a_max, mlx_stream s) int
mlx_closure_apply(Pointer<mlx_vector_array> res, mlx_closure cls, mlx_vector_array input) int
mlx_closure_custom_apply(Pointer<mlx_vector_array> res, mlx_closure_custom cls, mlx_vector_array input_0, mlx_vector_array input_1, mlx_vector_array input_2) int
mlx_closure_custom_free(mlx_closure_custom cls) int
mlx_closure_custom_jvp_apply(Pointer<mlx_vector_array> res, mlx_closure_custom_jvp cls, mlx_vector_array input_0, mlx_vector_array input_1, Pointer<Int> input_2, int input_2_num) int
mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) int
mlx_closure_custom_jvp_new() mlx_closure_custom_jvp
mlx_closure_custom_jvp_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_vector_array, Pointer<Int>, Size)>> fun) mlx_closure_custom_jvp
mlx_closure_custom_jvp_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_vector_array, Pointer<Int>, Size, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure_custom_jvp
mlx_closure_custom_jvp_set(Pointer<mlx_closure_custom_jvp> cls, mlx_closure_custom_jvp src) int
mlx_closure_custom_new() mlx_closure_custom
mlx_closure_custom_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_vector_array, mlx_vector_array)>> fun) mlx_closure_custom
mlx_closure_custom_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_vector_array, mlx_vector_array, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure_custom
mlx_closure_custom_set(Pointer<mlx_closure_custom> cls, mlx_closure_custom src) int
mlx_closure_custom_vmap_apply(Pointer<mlx_vector_array> res_0, Pointer<mlx_vector_int> res_1, mlx_closure_custom_vmap cls, mlx_vector_array input_0, Pointer<Int> input_1, int input_1_num) int
mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) int
mlx_closure_custom_vmap_new() mlx_closure_custom_vmap
mlx_closure_custom_vmap_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, Pointer<mlx_vector_int>, mlx_vector_array, Pointer<Int>, Size)>> fun) mlx_closure_custom_vmap
mlx_closure_custom_vmap_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, Pointer<mlx_vector_int>, mlx_vector_array, Pointer<Int>, Size, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure_custom_vmap
mlx_closure_custom_vmap_set(Pointer<mlx_closure_custom_vmap> cls, mlx_closure_custom_vmap src) int
mlx_closure_free(mlx_closure cls) int
mlx_closure_kwargs_apply(Pointer<mlx_vector_array> res, mlx_closure_kwargs cls, mlx_vector_array input_0, mlx_map_string_to_array input_1) int
mlx_closure_kwargs_free(mlx_closure_kwargs cls) int
mlx_closure_kwargs_new() mlx_closure_kwargs
mlx_closure_kwargs_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_map_string_to_array)>> fun) mlx_closure_kwargs
mlx_closure_kwargs_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, mlx_map_string_to_array, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure_kwargs
mlx_closure_kwargs_set(Pointer<mlx_closure_kwargs> cls, mlx_closure_kwargs src) int
mlx_closure_new() mlx_closure
mlx_closure_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array)>> fun) mlx_closure
mlx_closure_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, mlx_vector_array, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure
mlx_closure_new_unary(Pointer<NativeFunction<Int Function(Pointer<mlx_array>, mlx_array)>> fun) mlx_closure
mlx_closure_set(Pointer<mlx_closure> cls, mlx_closure src) int
mlx_closure_value_and_grad_apply(Pointer<mlx_vector_array> res_0, Pointer<mlx_vector_array> res_1, mlx_closure_value_and_grad cls, mlx_vector_array input) int
mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) int
mlx_closure_value_and_grad_new() mlx_closure_value_and_grad
mlx_closure_value_and_grad_new_func(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, Pointer<mlx_vector_array>, mlx_vector_array)>> fun) mlx_closure_value_and_grad
mlx_closure_value_and_grad_new_func_payload(Pointer<NativeFunction<Int Function(Pointer<mlx_vector_array>, Pointer<mlx_vector_array>, mlx_vector_array, Pointer<Void>)>> fun, Pointer<Void> payload, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) mlx_closure_value_and_grad
mlx_closure_value_and_grad_set(Pointer<mlx_closure_value_and_grad> cls, mlx_closure_value_and_grad src) int
mlx_compile(Pointer<mlx_closure> res, mlx_closure fun, bool shapeless) int
mlx_concatenate(Pointer<mlx_array> res, mlx_vector_array arrays, mlx_stream s) int
mlx_concatenate_axis(Pointer<mlx_array> res, mlx_vector_array arrays, int axis, mlx_stream s) int
mlx_conjugate(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_contiguous(Pointer<mlx_array> res, mlx_array a, bool allow_col_major, mlx_stream s) int
mlx_conv1d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride, int padding, int dilation, int groups, mlx_stream s) int
mlx_conv2d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int groups, mlx_stream s) int
mlx_conv3d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int groups, mlx_stream s) int
mlx_conv_general(Pointer<mlx_array> res, mlx_array input, mlx_array weight, Pointer<Int> stride, int stride_num, Pointer<Int> padding_lo, int padding_lo_num, Pointer<Int> padding_hi, int padding_hi_num, Pointer<Int> kernel_dilation, int kernel_dilation_num, Pointer<Int> input_dilation, int input_dilation_num, int groups, bool flip, mlx_stream s) int
mlx_conv_transpose1d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride, int padding, int dilation, int output_padding, int groups, mlx_stream s) int
mlx_conv_transpose2d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride_0, int stride_1, int padding_0, int padding_1, int dilation_0, int dilation_1, int output_padding_0, int output_padding_1, int groups, mlx_stream s) int
mlx_conv_transpose3d(Pointer<mlx_array> res, mlx_array input, mlx_array weight, int stride_0, int stride_1, int stride_2, int padding_0, int padding_1, int padding_2, int dilation_0, int dilation_1, int dilation_2, int output_padding_0, int output_padding_1, int output_padding_2, int groups, mlx_stream s) int
mlx_copy(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_cos(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_cosh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_cuda_is_available(Pointer<Bool> res) int
CUDA-specific operations (Doxygen group `cuda`).
mlx_cummax(Pointer<mlx_array> res, mlx_array a, int axis, bool reverse, bool inclusive, mlx_stream s) int
mlx_cummin(Pointer<mlx_array> res, mlx_array a, int axis, bool reverse, bool inclusive, mlx_stream s) int
mlx_cumprod(Pointer<mlx_array> res, mlx_array a, int axis, bool reverse, bool inclusive, mlx_stream s) int
mlx_cumsum(Pointer<mlx_array> res, mlx_array a, int axis, bool reverse, bool inclusive, mlx_stream s) int
mlx_custom_function(Pointer<mlx_closure> res, mlx_closure fun, mlx_closure_custom fun_vjp, mlx_closure_custom_jvp fun_jvp, mlx_closure_custom_vmap fun_vmap) int
mlx_custom_vjp(Pointer<mlx_closure> res, mlx_closure fun, mlx_closure_custom fun_vjp) int
mlx_default_cpu_stream_new() mlx_stream
Returns the current default CPU stream.
mlx_default_gpu_stream_new() mlx_stream
Returns the current default GPU stream.
mlx_degrees(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_depends(Pointer<mlx_vector_array> res, mlx_vector_array inputs, mlx_vector_array dependencies) int
mlx_dequantize(Pointer<mlx_array> res, mlx_array w, mlx_array scales, mlx_array biases, mlx_optional_int group_size, mlx_optional_int bits, Pointer<Char> mode, mlx_array global_scale, mlx_optional_dtype dtype, mlx_stream s) int
mlx_detail_compile(Pointer<mlx_closure> res, mlx_closure fun, int fun_id, bool shapeless, Pointer<Uint64> constants, int constants_num) int
mlx_detail_compile_clear_cache() int
mlx_detail_compile_erase(int fun_id) int
mlx_detail_vmap_replace(Pointer<mlx_vector_array> res, mlx_vector_array inputs, mlx_vector_array s_inputs, mlx_vector_array s_outputs, Pointer<Int> in_axes, int in_axes_num, Pointer<Int> out_axes, int out_axes_num) int
Implementation-detail operations (Doxygen group `transforms_impl`).
mlx_detail_vmap_trace(Pointer<mlx_vector_array> res_0, Pointer<mlx_vector_array> res_1, mlx_closure fun, mlx_vector_array inputs, Pointer<Int> in_axes, int in_axes_num) int
mlx_device_count(Pointer<Int> count, mlx_device_type_ type) int
mlx_device_equal(mlx_device lhs, mlx_device rhs) bool
Check if devices are the same.
mlx_device_free(mlx_device dev) int
Free a device.
mlx_device_get_index(Pointer<Int> index, mlx_device dev) int
Returns the index of the device.
mlx_device_get_type(Pointer<UnsignedInt> type, mlx_device dev) int
Returns the type of the device.
mlx_device_info_free(mlx_device_info info) int
Free a device info object.
mlx_device_info_get(Pointer<mlx_device_info> info, mlx_device dev) int
Get device information for a device.
mlx_device_info_get_keys(Pointer<mlx_vector_string> keys, mlx_device_info info) int
Get all keys from device info. Returns 0 on success, 1 on error.
mlx_device_info_get_size(Pointer<Size> value, mlx_device_info info, Pointer<Char> key) int
Get a size_t value from device info. Returns 0 on success, 1 on error, 2 if key not found or wrong type.
mlx_device_info_get_string(Pointer<Pointer<Char>> value, mlx_device_info info, Pointer<Char> key) int
Get a string value from device info. Returns 0 on success, 1 on error, 2 if key not found or wrong type.
mlx_device_info_has_key(Pointer<Bool> exists, mlx_device_info info, Pointer<Char> key) int
Check if a key exists in the device info. Returns 0 on success, 1 on error. Sets *exists to true if the key exists, false otherwise.
mlx_device_info_is_string(Pointer<Bool> is_string, mlx_device_info info, Pointer<Char> key) int
Check if a value is a string type. Returns 0 on success, 1 on error. Sets *is_string to true if the value is a string, false if it's a size_t.
mlx_device_info_new() mlx_device_info
Returns a new empty device info object.
mlx_device_is_available(Pointer<Bool> avail, mlx_device dev) int
Check if device is available.
mlx_device_new() mlx_device
Returns a new empty device.
mlx_device_new_type(mlx_device_type_ type, int index) mlx_device
mlx_device_set(Pointer<mlx_device> dev, mlx_device src) int
Set device to provided src device.
mlx_device_tostring(Pointer<mlx_string> str, mlx_device dev) int
Get device description.
mlx_diag(Pointer<mlx_array> res, mlx_array a, int k, mlx_stream s) int
mlx_diagonal(Pointer<mlx_array> res, mlx_array a, int offset, int axis1, int axis2, mlx_stream s) int
mlx_disable_compile() int
mlx_distributed_all_gather(Pointer<mlx_array> res, mlx_array x, mlx_distributed_group group, mlx_stream S) int
Distributed collectives (Doxygen group `distributed`).
mlx_distributed_all_max(Pointer<mlx_array> res, mlx_array x, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_all_min(Pointer<mlx_array> res, mlx_array x, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_all_sum(Pointer<mlx_array> res, mlx_array x, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_group_rank(mlx_distributed_group group) int
Get the rank.
mlx_distributed_group_size(mlx_distributed_group group) int
Get the group size.
mlx_distributed_group_split(mlx_distributed_group group, int color, int key) mlx_distributed_group
Split the group.
mlx_distributed_init(bool strict) mlx_distributed_group
Initialize distributed.
mlx_distributed_is_available() bool
Check if distributed is available.
mlx_distributed_recv(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, int src, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_recv_like(Pointer<mlx_array> res, mlx_array x, int src, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_send(Pointer<mlx_array> res, mlx_array x, int dst, mlx_distributed_group group, mlx_stream s) int
mlx_distributed_sum_scatter(Pointer<mlx_array> res, mlx_array x, mlx_distributed_group group, mlx_stream s) int
mlx_divide(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_divmod(Pointer<mlx_vector_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_dtype_size(mlx_dtype_ dtype) int
mlx_einsum(Pointer<mlx_array> res, Pointer<Char> subscripts, mlx_vector_array operands, mlx_stream s) int
mlx_enable_compile() int
mlx_equal(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_erf(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_erfinv(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_eval(mlx_vector_array outputs) int
mlx_exp(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_expand_dims(Pointer<mlx_array> res, mlx_array a, int axis, mlx_stream s) int
mlx_expand_dims_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_expm1(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_export_function(Pointer<Char> file, mlx_closure fun, mlx_vector_array args, bool shapeless) int
Function serialization (Doxygen group `export`).
mlx_export_function_kwargs(Pointer<Char> file, mlx_closure_kwargs fun, mlx_vector_array args, mlx_map_string_to_array kwargs, bool shapeless) int
mlx_eye(Pointer<mlx_array> res, int n, int m, int k, mlx_dtype_ dtype, mlx_stream s) int
mlx_fast_cuda_kernel_apply(Pointer<mlx_vector_array> outputs, mlx_fast_cuda_kernel cls, mlx_vector_array inputs, mlx_fast_cuda_kernel_config config, mlx_stream stream) int
mlx_fast_cuda_kernel_config_add_output_arg(mlx_fast_cuda_kernel_config cls, Pointer<Int> shape, int size, mlx_dtype_ dtype) int
mlx_fast_cuda_kernel_config_add_template_arg_bool(mlx_fast_cuda_kernel_config cls, Pointer<Char> name, bool value) int
mlx_fast_cuda_kernel_config_add_template_arg_dtype(mlx_fast_cuda_kernel_config cls, Pointer<Char> name, mlx_dtype_ dtype) int
mlx_fast_cuda_kernel_config_add_template_arg_int(mlx_fast_cuda_kernel_config cls, Pointer<Char> name, int value) int
mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) void
mlx_fast_cuda_kernel_config_new() mlx_fast_cuda_kernel_config
mlx_fast_cuda_kernel_config_set_grid(mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3) int
mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, double value) int
mlx_fast_cuda_kernel_config_set_thread_group(mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3) int
mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose) int
mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) void
mlx_fast_cuda_kernel_new(Pointer<Char> name, mlx_vector_string input_names, mlx_vector_string output_names, Pointer<Char> source, Pointer<Char> header, bool ensure_row_contiguous, int shared_memory) mlx_fast_cuda_kernel
mlx_fast_layer_norm(Pointer<mlx_array> res, mlx_array x, mlx_array weight, mlx_array bias, double eps, mlx_stream s) int
mlx_fast_metal_kernel_apply(Pointer<mlx_vector_array> outputs, mlx_fast_metal_kernel cls, mlx_vector_array inputs, mlx_fast_metal_kernel_config config, mlx_stream stream) int
mlx_fast_metal_kernel_config_add_output_arg(mlx_fast_metal_kernel_config cls, Pointer<Int> shape, int size, mlx_dtype_ dtype) int
mlx_fast_metal_kernel_config_add_template_arg_bool(mlx_fast_metal_kernel_config cls, Pointer<Char> name, bool value) int
mlx_fast_metal_kernel_config_add_template_arg_dtype(mlx_fast_metal_kernel_config cls, Pointer<Char> name, mlx_dtype_ dtype) int
mlx_fast_metal_kernel_config_add_template_arg_int(mlx_fast_metal_kernel_config cls, Pointer<Char> name, int value) int
mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) void
mlx_fast_metal_kernel_config_new() mlx_fast_metal_kernel_config
mlx_fast_metal_kernel_config_set_grid(mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3) int
mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, double value) int
mlx_fast_metal_kernel_config_set_thread_group(mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3) int
mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose) int
mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) void
mlx_fast_metal_kernel_new(Pointer<Char> name, mlx_vector_string input_names, mlx_vector_string output_names, Pointer<Char> source, Pointer<Char> header, bool ensure_row_contiguous, bool atomic_outputs) mlx_fast_metal_kernel
mlx_fast_rms_norm(Pointer<mlx_array> res, mlx_array x, mlx_array weight, double eps, mlx_stream s) int
mlx_fast_rope(Pointer<mlx_array> res, mlx_array x, int dims, bool traditional, mlx_optional_float base, double scale, int offset, mlx_array freqs, mlx_stream s) int
mlx_fast_rope_dynamic(Pointer<mlx_array> res, mlx_array x, int dims, bool traditional, mlx_optional_float base, double scale, mlx_array offset, mlx_array freqs, mlx_stream s) int
mlx_fast_scaled_dot_product_attention(Pointer<mlx_array> res, mlx_array queries, mlx_array keys, mlx_array values, double scale, Pointer<Char> mask_mode, mlx_array mask_arr, mlx_array sinks, mlx_stream s) int
mlx_fft_fft(Pointer<mlx_array> res, mlx_array a, int n, int axis, mlx_stream s) int
FFT operations (Doxygen group `fft`).
mlx_fft_fft2(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_fftn(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_fftshift(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_ifft(Pointer<mlx_array> res, mlx_array a, int n, int axis, mlx_stream s) int
mlx_fft_ifft2(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_ifftn(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_ifftshift(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_irfft(Pointer<mlx_array> res, mlx_array a, int n, int axis, mlx_stream s) int
mlx_fft_irfft2(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_irfftn(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_rfft(Pointer<mlx_array> res, mlx_array a, int n, int axis, mlx_stream s) int
mlx_fft_rfft2(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_fft_rfftn(Pointer<mlx_array> res, mlx_array a, Pointer<Int> n, int n_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_flatten(Pointer<mlx_array> res, mlx_array a, int start_axis, int end_axis, mlx_stream s) int
mlx_floor(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_floor_divide(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_from_fp8(Pointer<mlx_array> res, mlx_array x, mlx_dtype_ dtype, mlx_stream s) int
mlx_full(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_array vals, mlx_dtype_ dtype, mlx_stream s) int
mlx_full_like(Pointer<mlx_array> res, mlx_array a, mlx_array vals, mlx_dtype_ dtype, mlx_stream s) int
mlx_function_exporter_apply(mlx_function_exporter xfunc, mlx_vector_array args) int
mlx_function_exporter_apply_kwargs(mlx_function_exporter xfunc, mlx_vector_array args, mlx_map_string_to_array kwargs) int
mlx_function_exporter_free(mlx_function_exporter xfunc) int
mlx_function_exporter_new(Pointer<Char> file, mlx_closure fun, bool shapeless) mlx_function_exporter
mlx_gather(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, Pointer<Int> axes, int axes_num, Pointer<Int> slice_sizes, int slice_sizes_num, mlx_stream s) int
mlx_gather_mm(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_array lhs_indices, mlx_array rhs_indices, bool sorted_indices, mlx_stream s) int
mlx_gather_qmm(Pointer<mlx_array> res, mlx_array x, mlx_array w, mlx_array scales, mlx_array biases, mlx_array lhs_indices, mlx_array rhs_indices, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, Pointer<Char> mode, bool sorted_indices, mlx_stream s) int
mlx_gather_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, int axis, Pointer<Int> slice_sizes, int slice_sizes_num, mlx_stream s) int
mlx_get_active_memory(Pointer<Size> res) int
mlx_get_cache_memory(Pointer<Size> res) int
mlx_get_default_device(Pointer<mlx_device> dev) int
Returns the default MLX device.
mlx_get_default_stream(Pointer<mlx_stream> stream, mlx_device dev) int
Returns the default stream on the given device.
mlx_get_memory_limit(Pointer<Size> res) int
mlx_get_peak_memory(Pointer<Size> res) int
mlx_greater(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_greater_equal(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_hadamard_transform(Pointer<mlx_array> res, mlx_array a, mlx_optional_float scale, mlx_stream s) int
mlx_hamming(Pointer<mlx_array> res, int M, mlx_stream s) int
mlx_hanning(Pointer<mlx_array> res, int M, mlx_stream s) int
mlx_identity(Pointer<mlx_array> res, int n, mlx_dtype_ dtype, mlx_stream s) int
mlx_imag(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_imported_function_apply(Pointer<mlx_vector_array> res, mlx_imported_function xfunc, mlx_vector_array args) int
mlx_imported_function_apply_kwargs(Pointer<mlx_vector_array> res, mlx_imported_function xfunc, mlx_vector_array args, mlx_map_string_to_array kwargs) int
mlx_imported_function_free(mlx_imported_function xfunc) int
mlx_imported_function_new(Pointer<Char> file) mlx_imported_function
mlx_inner(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_io_reader_descriptor(Pointer<Pointer<Void>> desc_, mlx_io_reader io) int
Get IO reader user descriptor.
mlx_io_reader_free(mlx_io_reader io) int
Free IO reader.
mlx_io_reader_new(Pointer<Void> desc, mlx_io_vtable vtable) mlx_io_reader
Returns a new custom IO reader. The vtable operates on the user descriptor desc.
mlx_io_reader_tostring(Pointer<mlx_string> str_, mlx_io_reader io) int
Get IO reader description.
mlx_io_writer_descriptor(Pointer<Pointer<Void>> desc_, mlx_io_writer io) int
Get IO writer user descriptor.
mlx_io_writer_free(mlx_io_writer io) int
Free IO writer.
mlx_io_writer_new(Pointer<Void> desc, mlx_io_vtable vtable) mlx_io_writer
Returns a new custom IO writer. The vtable operates on the user descriptor desc.
mlx_io_writer_tostring(Pointer<mlx_string> str_, mlx_io_writer io) int
Get IO writer description.
mlx_isclose(Pointer<mlx_array> res, mlx_array a, mlx_array b, double rtol, double atol, bool equal_nan, mlx_stream s) int
mlx_isfinite(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_isinf(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_isnan(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_isneginf(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_isposinf(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_jvp(Pointer<mlx_vector_array> res_0, Pointer<mlx_vector_array> res_1, mlx_closure fun, mlx_vector_array primals, mlx_vector_array tangents) int
mlx_kron(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_left_shift(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_less(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_less_equal(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_linalg_cholesky(Pointer<mlx_array> res, mlx_array a, bool upper, mlx_stream s) int
Linear algebra operations (Doxygen group `linalg`).
mlx_linalg_cholesky_inv(Pointer<mlx_array> res, mlx_array a, bool upper, mlx_stream s) int
mlx_linalg_cross(Pointer<mlx_array> res, mlx_array a, mlx_array b, int axis, mlx_stream s) int
mlx_linalg_eig(Pointer<mlx_array> res_0, Pointer<mlx_array> res_1, mlx_array a, mlx_stream s) int
mlx_linalg_eigh(Pointer<mlx_array> res_0, Pointer<mlx_array> res_1, mlx_array a, Pointer<Char> UPLO, mlx_stream s) int
mlx_linalg_eigvals(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_linalg_eigvalsh(Pointer<mlx_array> res, mlx_array a, Pointer<Char> UPLO, mlx_stream s) int
mlx_linalg_inv(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_linalg_lu(Pointer<mlx_vector_array> res, mlx_array a, mlx_stream s) int
mlx_linalg_lu_factor(Pointer<mlx_array> res_0, Pointer<mlx_array> res_1, mlx_array a, mlx_stream s) int
mlx_linalg_norm(Pointer<mlx_array> res, mlx_array a, double ord, Pointer<Int> axis, int axis_num, bool keepdims, mlx_stream s) int
mlx_linalg_norm_l2(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axis, int axis_num, bool keepdims, mlx_stream s) int
mlx_linalg_norm_matrix(Pointer<mlx_array> res, mlx_array a, Pointer<Char> ord, Pointer<Int> axis, int axis_num, bool keepdims, mlx_stream s) int
mlx_linalg_pinv(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_linalg_qr(Pointer<mlx_array> res_0, Pointer<mlx_array> res_1, mlx_array a, mlx_stream s) int
mlx_linalg_solve(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_linalg_solve_triangular(Pointer<mlx_array> res, mlx_array a, mlx_array b, bool upper, mlx_stream s) int
mlx_linalg_svd(Pointer<mlx_vector_array> res, mlx_array a, bool compute_uv, mlx_stream s) int
mlx_linalg_tri_inv(Pointer<mlx_array> res, mlx_array a, bool upper, mlx_stream s) int
mlx_linspace(Pointer<mlx_array> res, double start, double stop, int num, mlx_dtype_ dtype, mlx_stream s) int
mlx_load(Pointer<mlx_array> res, Pointer<Char> file, mlx_stream s) int
mlx_load_reader(Pointer<mlx_array> res, mlx_io_reader in_stream, mlx_stream s) int
IO operations (Doxygen group `io`).
mlx_load_safetensors(Pointer<mlx_map_string_to_array> res_0, Pointer<mlx_map_string_to_string> res_1, Pointer<Char> file, mlx_stream s) int
mlx_load_safetensors_reader(Pointer<mlx_map_string_to_array> res_0, Pointer<mlx_map_string_to_string> res_1, mlx_io_reader in_stream, mlx_stream s) int
mlx_log(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_log10(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_log1p(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_log2(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_logaddexp(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_logcumsumexp(Pointer<mlx_array> res, mlx_array a, int axis, bool reverse, bool inclusive, mlx_stream s) int
mlx_logical_and(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_logical_not(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_logical_or(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_logsumexp(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_logsumexp_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_logsumexp_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_map_string_to_array_free(mlx_map_string_to_array map) int
Free a string-to-array map.
mlx_map_string_to_array_get(Pointer<mlx_array> value, mlx_map_string_to_array map, Pointer<Char> key) int
Returns the value indexed at the specified key in the map.
mlx_map_string_to_array_insert(mlx_map_string_to_array map, Pointer<Char> key, mlx_array value) int
Insert a new value at the specified key in the map.
mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) int
Free iterator.
mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map) mlx_map_string_to_array_iterator
Returns a new iterator over the given map.
mlx_map_string_to_array_iterator_next(Pointer<Pointer<Char>> key, Pointer<mlx_array> value, mlx_map_string_to_array_iterator it) int
Increment iterator.
mlx_map_string_to_array_new() mlx_map_string_to_array
Returns a new empty string-to-array map.
mlx_map_string_to_array_set(Pointer<mlx_map_string_to_array> map, mlx_map_string_to_array src) int
Set map to provided src map.
mlx_map_string_to_string_free(mlx_map_string_to_string map) int
Free a string-to-string map.
mlx_map_string_to_string_get(Pointer<Pointer<Char>> value, mlx_map_string_to_string map, Pointer<Char> key) int
Returns the value indexed at the specified key in the map.
mlx_map_string_to_string_insert(mlx_map_string_to_string map, Pointer<Char> key, Pointer<Char> value) int
Insert a new value at the specified key in the map.
mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it) int
Free iterator.
mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map) mlx_map_string_to_string_iterator
Returns a new iterator over the given map.
mlx_map_string_to_string_iterator_next(Pointer<Pointer<Char>> key, Pointer<Pointer<Char>> value, mlx_map_string_to_string_iterator it) int
Increment iterator.
mlx_map_string_to_string_new() mlx_map_string_to_string
Returns a new empty string-to-string map.
mlx_map_string_to_string_set(Pointer<mlx_map_string_to_string> map, mlx_map_string_to_string src) int
Set map to provided src map.
mlx_masked_scatter(Pointer<mlx_array> res, mlx_array a, mlx_array mask, mlx_array src, mlx_stream s) int
mlx_matmul(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_max(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_max_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_max_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_maximum(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_mean(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_mean_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_mean_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_median(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_meshgrid(Pointer<mlx_vector_array> res, mlx_vector_array arrays, bool sparse, Pointer<Char> indexing, mlx_stream s) int
mlx_metal_is_available(Pointer<Bool> res) int
Metal-specific operations (Doxygen group `metal`).
mlx_metal_start_capture(Pointer<Char> path) int
mlx_metal_stop_capture() int
mlx_min(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_min_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_min_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_minimum(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_moveaxis(Pointer<mlx_array> res, mlx_array a, int source, int destination, mlx_stream s) int
mlx_multiply(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_nan_to_num(Pointer<mlx_array> res, mlx_array a, double nan, mlx_optional_float posinf, mlx_optional_float neginf, mlx_stream s) int
mlx_negative(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_not_equal(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_number_of_elements(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool inverted, mlx_dtype_ dtype, mlx_stream s) int
mlx_ones(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_stream s) int
mlx_ones_like(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_outer(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_pad(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, Pointer<Int> low_pad_size, int low_pad_size_num, Pointer<Int> high_pad_size, int high_pad_size_num, mlx_array pad_value, Pointer<Char> mode, mlx_stream s) int
mlx_pad_symmetric(Pointer<mlx_array> res, mlx_array a, int pad_width, mlx_array pad_value, Pointer<Char> mode, mlx_stream s) int
mlx_partition(Pointer<mlx_array> res, mlx_array a, int kth, mlx_stream s) int
mlx_partition_axis(Pointer<mlx_array> res, mlx_array a, int kth, int axis, mlx_stream s) int
mlx_power(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_prod(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_prod_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_prod_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_put_along_axis(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array values, int axis, mlx_stream s) int
mlx_qqmm(Pointer<mlx_array> res, mlx_array x, mlx_array w, mlx_array w_scales, mlx_optional_int group_size, mlx_optional_int bits, Pointer<Char> mode, mlx_array global_scale_x, mlx_array global_scale_w, mlx_stream s) int
mlx_quantize(Pointer<mlx_vector_array> res, mlx_array w, mlx_optional_int group_size, mlx_optional_int bits, Pointer<Char> mode, mlx_array global_scale, mlx_stream s) int
mlx_quantized_matmul(Pointer<mlx_array> res, mlx_array x, mlx_array w, mlx_array scales, mlx_array biases, bool transpose, mlx_optional_int group_size, mlx_optional_int bits, Pointer<Char> mode, mlx_stream s) int
mlx_radians(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_random_bernoulli(Pointer<mlx_array> res, mlx_array p, Pointer<Int> shape, int shape_num, mlx_array key, mlx_stream s) int
Random number operations (Doxygen group `random`).
mlx_random_bits(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, int width, mlx_array key, mlx_stream s) int
mlx_random_categorical(Pointer<mlx_array> res, mlx_array logits, int axis, mlx_array key, mlx_stream s) int
mlx_random_categorical_num_samples(Pointer<mlx_array> res, mlx_array logits_, int axis, int num_samples, mlx_array key, mlx_stream s) int
mlx_random_categorical_shape(Pointer<mlx_array> res, mlx_array logits, int axis, Pointer<Int> shape, int shape_num, mlx_array key, mlx_stream s) int
mlx_random_gumbel(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array key, mlx_stream s) int
mlx_random_key(Pointer<mlx_array> res, int seed) int
mlx_random_laplace(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, double loc, double scale, mlx_array key, mlx_stream s) int
mlx_random_multivariate_normal(Pointer<mlx_array> res, mlx_array mean, mlx_array cov, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array key, mlx_stream s) int
mlx_random_normal(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, double loc, double scale, mlx_array key, mlx_stream s) int
mlx_random_normal_broadcast(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array loc, mlx_array scale, mlx_array key, mlx_stream s) int
mlx_random_permutation(Pointer<mlx_array> res, mlx_array x, int axis, mlx_array key, mlx_stream s) int
mlx_random_permutation_arange(Pointer<mlx_array> res, int x, mlx_array key, mlx_stream s) int
mlx_random_randint(Pointer<mlx_array> res, mlx_array low, mlx_array high, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array key, mlx_stream s) int
mlx_random_seed(int seed) int
mlx_random_split(Pointer<mlx_array> res_0, Pointer<mlx_array> res_1, mlx_array key, mlx_stream s) int
mlx_random_split_num(Pointer<mlx_array> res, mlx_array key, int num, mlx_stream s) int
mlx_random_truncated_normal(Pointer<mlx_array> res, mlx_array lower, mlx_array upper, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array key, mlx_stream s) int
mlx_random_uniform(Pointer<mlx_array> res, mlx_array low, mlx_array high, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_array key, mlx_stream s) int
mlx_real(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_reciprocal(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_remainder(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_repeat(Pointer<mlx_array> res, mlx_array arr, int repeats, mlx_stream s) int
mlx_repeat_axis(Pointer<mlx_array> res, mlx_array arr, int repeats, int axis, mlx_stream s) int
mlx_reset_peak_memory() int
mlx_reshape(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shape, int shape_num, mlx_stream s) int
mlx_right_shift(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_roll(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shift, int shift_num, mlx_stream s) int
mlx_roll_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shift, int shift_num, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_roll_axis(Pointer<mlx_array> res, mlx_array a, Pointer<Int> shift, int shift_num, int axis, mlx_stream s) int
mlx_round(Pointer<mlx_array> res, mlx_array a, int decimals, mlx_stream s) int
mlx_rsqrt(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_save(Pointer<Char> file, mlx_array a) int
mlx_save_safetensors(Pointer<Char> file, mlx_map_string_to_array param, mlx_map_string_to_string metadata) int
mlx_save_safetensors_writer(mlx_io_writer in_stream, mlx_map_string_to_array param, mlx_map_string_to_string metadata) int
mlx_save_writer(mlx_io_writer out_stream, mlx_array a) int
mlx_scatter(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, mlx_array updates, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_scatter_add(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, mlx_array updates, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_scatter_add_axis(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array values, int axis, mlx_stream s) int
mlx_scatter_add_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array updates, int axis, mlx_stream s) int
mlx_scatter_max(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, mlx_array updates, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_scatter_max_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array updates, int axis, mlx_stream s) int
mlx_scatter_min(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, mlx_array updates, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_scatter_min_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array updates, int axis, mlx_stream s) int
mlx_scatter_prod(Pointer<mlx_array> res, mlx_array a, mlx_vector_array indices, mlx_array updates, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_scatter_prod_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array updates, int axis, mlx_stream s) int
mlx_scatter_single(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_array updates, int axis, mlx_stream s) int
mlx_segmented_mm(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_array segments, mlx_stream s) int
mlx_set_cache_limit(Pointer<Size> res, int limit) int
mlx_set_compile_mode(mlx_compile_mode_ mode) int
mlx_set_default_device(mlx_device dev) int
Set the default MLX device.
mlx_set_default_stream(mlx_stream stream) int
Set default stream.
mlx_set_error_handler(mlx_error_handler_func handler, Pointer<Void> data, Pointer<NativeFunction<Void Function(Pointer<Void>)>> dtor) void
Set the error handler.
mlx_set_memory_limit(Pointer<Size> res, int limit) int
mlx_set_wired_limit(Pointer<Size> res, int limit) int
mlx_sigmoid(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_sign(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_sin(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_sinh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_slice(Pointer<mlx_array> res, mlx_array a, Pointer<Int> start, int start_num, Pointer<Int> stop, int stop_num, Pointer<Int> strides, int strides_num, mlx_stream s) int
mlx_slice_dynamic(Pointer<mlx_array> res, mlx_array a, mlx_array start, Pointer<Int> axes, int axes_num, Pointer<Int> slice_size, int slice_size_num, mlx_stream s) int
mlx_slice_update(Pointer<mlx_array> res, mlx_array src, mlx_array update, Pointer<Int> start, int start_num, Pointer<Int> stop, int stop_num, Pointer<Int> strides, int strides_num, mlx_stream s) int
mlx_slice_update_dynamic(Pointer<mlx_array> res, mlx_array src, mlx_array update, mlx_array start, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_softmax(Pointer<mlx_array> res, mlx_array a, bool precise, mlx_stream s) int
mlx_softmax_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool precise, mlx_stream s) int
mlx_softmax_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool precise, mlx_stream s) int
mlx_sort(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_sort_axis(Pointer<mlx_array> res, mlx_array a, int axis, mlx_stream s) int
mlx_split(Pointer<mlx_vector_array> res, mlx_array a, int num_splits, int axis, mlx_stream s) int
mlx_split_sections(Pointer<mlx_vector_array> res, mlx_array a, Pointer<Int> indices, int indices_num, int axis, mlx_stream s) int
mlx_sqrt(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_square(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_squeeze(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_squeeze_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_squeeze_axis(Pointer<mlx_array> res, mlx_array a, int axis, mlx_stream s) int
mlx_stack(Pointer<mlx_array> res, mlx_vector_array arrays, mlx_stream s) int
mlx_stack_axis(Pointer<mlx_array> res, mlx_vector_array arrays, int axis, mlx_stream s) int
mlx_std(Pointer<mlx_array> res, mlx_array a, bool keepdims, int ddof, mlx_stream s) int
mlx_std_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, int ddof, mlx_stream s) int
mlx_std_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, int ddof, mlx_stream s) int
mlx_stop_gradient(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) bool
Check if streams are the same.
mlx_stream_free(mlx_stream stream) int
Free a stream.
mlx_stream_get_device(Pointer<mlx_device> dev, mlx_stream stream) int
Return the device of the stream.
mlx_stream_get_index(Pointer<Int> index, mlx_stream stream) int
Return the index of the stream.
mlx_stream_new() mlx_stream
Returns a new empty stream.
mlx_stream_new_device(mlx_device dev) mlx_stream
Returns a new stream on a device.
mlx_stream_set(Pointer<mlx_stream> stream, mlx_stream src) int
Set stream to provided src stream.
mlx_stream_tostring(Pointer<mlx_string> str, mlx_stream stream) int
Get stream description.
mlx_string_data(mlx_string str) Pointer<Char>
Returns a pointer to the string contents. The pointer is valid for the lifetime of the string.
mlx_string_free(mlx_string str) int
Free string.
mlx_string_new() mlx_string
Returns a new empty string.
mlx_string_new_data(Pointer<Char> str) mlx_string
Returns a new string, copying the contents of str, which must be NUL-terminated (end with \0).
mlx_string_set(Pointer<mlx_string> str, mlx_string src) int
Set string to src string.
mlx_subtract(Pointer<mlx_array> res, mlx_array a, mlx_array b, mlx_stream s) int
mlx_sum(Pointer<mlx_array> res, mlx_array a, bool keepdims, mlx_stream s) int
mlx_sum_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, mlx_stream s) int
mlx_sum_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, mlx_stream s) int
mlx_swapaxes(Pointer<mlx_array> res, mlx_array a, int axis1, int axis2, mlx_stream s) int
mlx_synchronize(mlx_stream stream) int
Synchronize with the provided stream.
mlx_take(Pointer<mlx_array> res, mlx_array a, mlx_array indices, mlx_stream s) int
mlx_take_along_axis(Pointer<mlx_array> res, mlx_array a, mlx_array indices, int axis, mlx_stream s) int
mlx_take_axis(Pointer<mlx_array> res, mlx_array a, mlx_array indices, int axis, mlx_stream s) int
mlx_tan(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_tanh(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_tensordot(Pointer<mlx_array> res, mlx_array a, mlx_array b, Pointer<Int> axes_a, int axes_a_num, Pointer<Int> axes_b, int axes_b_num, mlx_stream s) int
mlx_tensordot_axis(Pointer<mlx_array> res, mlx_array a, mlx_array b, int axis, mlx_stream s) int
mlx_tile(Pointer<mlx_array> res, mlx_array arr, Pointer<Int> reps, int reps_num, mlx_stream s) int
mlx_to_fp8(Pointer<mlx_array> res, mlx_array x, mlx_stream s) int
mlx_topk(Pointer<mlx_array> res, mlx_array a, int k, mlx_stream s) int
mlx_topk_axis(Pointer<mlx_array> res, mlx_array a, int k, int axis, mlx_stream s) int
mlx_trace(Pointer<mlx_array> res, mlx_array a, int offset, int axis1, int axis2, mlx_dtype_ dtype, mlx_stream s) int
mlx_transpose(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int
mlx_transpose_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, mlx_stream s) int
mlx_tri(Pointer<mlx_array> res, int n, int m, int k, mlx_dtype_ type, mlx_stream s) int
mlx_tril(Pointer<mlx_array> res, mlx_array x, int k, mlx_stream s) int
mlx_triu(Pointer<mlx_array> res, mlx_array x, int k, mlx_stream s) int
mlx_unflatten(Pointer<mlx_array> res, mlx_array a, int axis, Pointer<Int> shape, int shape_num, mlx_stream s) int
mlx_value_and_grad(Pointer<mlx_closure_value_and_grad> res, mlx_closure fun, Pointer<Int> argnums, int argnums_num) int
mlx_var(Pointer<mlx_array> res, mlx_array a, bool keepdims, int ddof, mlx_stream s) int
mlx_var_axes(Pointer<mlx_array> res, mlx_array a, Pointer<Int> axes, int axes_num, bool keepdims, int ddof, mlx_stream s) int
mlx_var_axis(Pointer<mlx_array> res, mlx_array a, int axis, bool keepdims, int ddof, mlx_stream s) int
mlx_vector_array_append_data(mlx_vector_array vec, Pointer<mlx_array> data, int size) int
mlx_vector_array_append_value(mlx_vector_array vec, mlx_array val) int
mlx_vector_array_free(mlx_vector_array vec) int
mlx_vector_array_get(Pointer<mlx_array> res, mlx_vector_array vec, int idx) int
mlx_vector_array_new() mlx_vector_array
mlx_vector_array_new_data(Pointer<mlx_array> data, int size) mlx_vector_array
mlx_vector_array_new_value(mlx_array val) mlx_vector_array
mlx_vector_array_set(Pointer<mlx_vector_array> vec, mlx_vector_array src) int
mlx_vector_array_set_data(Pointer<mlx_vector_array> vec, Pointer<mlx_array> data, int size) int
mlx_vector_array_set_value(Pointer<mlx_vector_array> vec, mlx_array val) int
mlx_vector_array_size(mlx_vector_array vec) int
mlx_vector_int_append_data(mlx_vector_int vec, Pointer<Int> data, int size) int
mlx_vector_int_append_value(mlx_vector_int vec, int val) int
mlx_vector_int_free(mlx_vector_int vec) int
mlx_vector_int_get(Pointer<Int> res, mlx_vector_int vec, int idx) int
mlx_vector_int_new() mlx_vector_int
mlx_vector_int_new_data(Pointer<Int> data, int size) mlx_vector_int
mlx_vector_int_new_value(int val) mlx_vector_int
mlx_vector_int_set(Pointer<mlx_vector_int> vec, mlx_vector_int src) int
mlx_vector_int_set_data(Pointer<mlx_vector_int> vec, Pointer<Int> data, int size) int
mlx_vector_int_set_value(Pointer<mlx_vector_int> vec, int val) int
mlx_vector_int_size(mlx_vector_int vec) int
mlx_vector_string_append_data(mlx_vector_string vec, Pointer<Pointer<Char>> data, int size) int
mlx_vector_string_append_value(mlx_vector_string vec, Pointer<Char> val) int
mlx_vector_string_free(mlx_vector_string vec) int
mlx_vector_string_get(Pointer<Pointer<Char>> res, mlx_vector_string vec, int idx) int
mlx_vector_string_new() mlx_vector_string
mlx_vector_string_new_data(Pointer<Pointer<Char>> data, int size) mlx_vector_string
mlx_vector_string_new_value(Pointer<Char> val) mlx_vector_string
mlx_vector_string_set(Pointer<mlx_vector_string> vec, mlx_vector_string src) int
mlx_vector_string_set_data(Pointer<mlx_vector_string> vec, Pointer<Pointer<Char>> data, int size) int
mlx_vector_string_set_value(Pointer<mlx_vector_string> vec, Pointer<Char> val) int
mlx_vector_string_size(mlx_vector_string vec) int
mlx_vector_vector_array_append_data(mlx_vector_vector_array vec, Pointer<mlx_vector_array> data, int size) int
mlx_vector_vector_array_append_value(mlx_vector_vector_array vec, mlx_vector_array val) int
mlx_vector_vector_array_free(mlx_vector_vector_array vec) int
mlx_vector_vector_array_get(Pointer<mlx_vector_array> res, mlx_vector_vector_array vec, int idx) int
mlx_vector_vector_array_new() mlx_vector_vector_array
mlx_vector_vector_array_new_data(Pointer<mlx_vector_array> data, int size) mlx_vector_vector_array
mlx_vector_vector_array_new_value(mlx_vector_array val) mlx_vector_vector_array
mlx_vector_vector_array_set(Pointer<mlx_vector_vector_array> vec, mlx_vector_vector_array src) int
mlx_vector_vector_array_set_data(Pointer<mlx_vector_vector_array> vec, Pointer<mlx_vector_array> data, int size) int
mlx_vector_vector_array_set_value(Pointer<mlx_vector_vector_array> vec, mlx_vector_array val) int
mlx_vector_vector_array_size(mlx_vector_vector_array vec) int
mlx_version(Pointer<mlx_string> str_) int
mlx_view(Pointer<mlx_array> res, mlx_array a, mlx_dtype_ dtype, mlx_stream s) int
mlx_vjp(Pointer<mlx_vector_array> res_0, Pointer<mlx_vector_array> res_1, mlx_closure fun, mlx_vector_array primals, mlx_vector_array cotangents) int
mlx_where(Pointer<mlx_array> res, mlx_array condition, mlx_array x, mlx_array y, mlx_stream s) int
mlx_zeros(Pointer<mlx_array> res, Pointer<Int> shape, int shape_num, mlx_dtype_ dtype, mlx_stream s) int
mlx_zeros_like(Pointer<mlx_array> res, mlx_array a, mlx_stream s) int

Typedefs

Dartmlx_error_handler_funcFunction = void Function(Pointer<Char> msg, Pointer<Void> data)
mlx_array = mlx_array_
mlx_closure = mlx_closure_
mlx_closure_custom = mlx_closure_custom_
mlx_closure_custom_jvp = mlx_closure_custom_jvp_
mlx_closure_custom_vmap = mlx_closure_custom_vmap_
mlx_closure_kwargs = mlx_closure_kwargs_
mlx_closure_value_and_grad = mlx_closure_value_and_grad_
mlx_device = mlx_device_
mlx_device_info = mlx_device_info_
mlx_distributed_group = mlx_distributed_group_
mlx_error_handler_func = Pointer<NativeFunction<mlx_error_handler_funcFunction>>
Error management (Doxygen group `mlx_error`).
mlx_error_handler_funcFunction = Void Function(Pointer<Char> msg, Pointer<Void> data)
mlx_fast_cuda_kernel = mlx_fast_cuda_kernel_
mlx_fast_cuda_kernel_config = mlx_fast_cuda_kernel_config_
mlx_fast_metal_kernel = mlx_fast_metal_kernel_
mlx_fast_metal_kernel_config = mlx_fast_metal_kernel_config_
mlx_function_exporter = mlx_function_exporter_
mlx_imported_function = mlx_imported_function_
mlx_io_reader = mlx_io_reader_
mlx_io_vtable = mlx_io_vtable_
mlx_io_writer = mlx_io_writer_
mlx_map_string_to_array = mlx_map_string_to_array_
mlx_map_string_to_array_iterator = mlx_map_string_to_array_iterator_
mlx_map_string_to_string = mlx_map_string_to_string_
mlx_map_string_to_string_iterator = mlx_map_string_to_string_iterator_
mlx_optional_dtype = mlx_optional_dtype_
mlx_optional_float = mlx_optional_float_
mlx_optional_int = mlx_optional_int_
mlx_stream = mlx_stream_
mlx_string = mlx_string_
mlx_vector_array = mlx_vector_array_
mlx_vector_int = mlx_vector_int_
mlx_vector_string = mlx_vector_string_
mlx_vector_vector_array = mlx_vector_vector_array_