arange function

Tensor arange(
  num start,
  num end,
  num step, {
  int dtype = float32,
  bool requiresGrad = false,
  Device? device_used,
})

Implementation

/// Creates a tensor by calling the native `Tensor_arange` with [start],
/// [end] and [step].
///
/// NOTE(review): presumably mirrors `torch.arange` (values from [start] up
/// to, but not including, [end], spaced by [step]). Confirm against the
/// native implementation.
///
/// - [start], [end], [step]: each is converted to a native `Scalar` of
///   [dtype]. For `float32`/`float64` the value is converted with
///   `toDouble()`. For `int32` it is converted with `toInt()`, which
///   truncates any fractional part.
/// - [dtype]: one of `float32` (default), `float64` or `int32`.
/// - [requiresGrad]: forwarded unchanged to the native call.
/// - [device_used]: target device; defaults to `device("cpu")` when null.
///
/// Throws an [Exception]:
/// - `"wrong type"` if [dtype] is not one of the supported values;
/// - carrying the native error message if the native layer reported one;
/// - `"null pointer"` if the native call returned a null tensor pointer.
Tensor arange(num start, num end, num step,
    {int dtype = float32, bool requiresGrad = false, Device? device_used}) {
  device_used ??= device("cpu");

  // Pick the Dart-number -> native-Scalar conversion that matches [dtype].
  // All three bounds go through the same converter, so the native call
  // always receives scalars of one consistent type.
  final Scalar Function(num) toScalar;
  if (dtype == float32) {
    toScalar = (num value) => float32_to_scalar(value.toDouble());
  } else if (dtype == float64) {
    toScalar = (num value) => float64_to_scalar(value.toDouble());
  } else if (dtype == int32) {
    toScalar = (num value) => int32_to_scalar(value.toInt());
  } else {
    throw Exception("wrong type");
  }

  // Convert each bound from its own argument.
  final Scalar startScalar = toScalar(start);
  final Scalar endScalar = toScalar(end);
  final Scalar stepScalar = toScalar(step);

  final resultTensorPtr = Tensor_arange(
      startScalar.scalarPtr,
      endScalar.scalarPtr,
      stepScalar.scalarPtr,
      dtype,
      device_used.device_type,
      device_used.device_index,
      requiresGrad);

  // Fetch any error recorded by the native call. Going by its name, the
  // helper also clears that slot. If an error message is present, surface it
  // as an exception.
  final errorMsg = _get_and_reset_last_err();
  if (errorMsg != nullptr) {
    final errorString = errorMsg.cast<Utf8>().toDartString();
    throw Exception(errorString);
  }

  // Guard against a null result even when no error message was reported.
  if (resultTensorPtr == nullptr) {
    throw Exception("null pointer");
  }
  return Tensor._internal(resultTensorPtr);
}