|
8 | 8 |
|
9 | 9 | #pragma once |
10 | 10 |
|
| 11 | +#include <executorch/runtime/core/error.h> // @manual |
| 12 | +#include <executorch/runtime/core/result.h> // @manual |
11 | 13 | #include <executorch/runtime/core/tensor_shape_dynamism.h> // @manual |
| 14 | +#include <executorch/runtime/platform/assert.h> // @manual |
12 | 15 | #include <executorch/runtime/platform/compiler.h> |
13 | 16 | #ifdef USE_ATEN_LIB |
14 | 17 | #include <ATen/Tensor.h> // @manual |
|
28 | 31 | #include <c10/util/quint2x4.h> // @manual |
29 | 32 | #include <c10/util/quint4x2.h> // @manual |
30 | 33 | #include <c10/util/quint8.h> // @manual |
| 34 | +#include <c10/util/safe_numerics.h> // @manual |
31 | 35 | #include <c10/util/string_view.h> // @manual |
32 | 36 | #include <torch/torch.h> |
33 | 37 | #else // use executor |
@@ -110,6 +114,32 @@ inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) { |
110 | 114 | c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim))); |
111 | 115 | } |
112 | 116 |
|
| 117 | +inline ::executorch::runtime::Result<ssize_t> safe_numel( |
| 118 | + const SizesType* sizes, |
| 119 | + ssize_t dim) { |
| 120 | + ET_CHECK_OR_RETURN_ERROR( |
| 121 | + dim == 0 || sizes != nullptr, |
| 122 | + InvalidArgument, |
| 123 | + "Sizes must be provided for non-scalar tensors"); |
| 124 | + ssize_t numel = 1; |
| 125 | + for (ssize_t i = 0; i < dim; i++) { |
| 126 | + ET_CHECK_OR_RETURN_ERROR( |
| 127 | + sizes[i] >= 0, |
| 128 | + InvalidArgument, |
| 129 | + "Size must be non-negative, got %zd at dimension %zd", |
| 130 | + static_cast<ssize_t>(sizes[i]), |
| 131 | + i); |
| 132 | + ssize_t next_numel; |
| 133 | + ET_CHECK_OR_RETURN_ERROR( |
| 134 | + !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel), |
| 135 | + InvalidArgument, |
| 136 | + "Overflow computing numel at dimension %zd", |
| 137 | + i); |
| 138 | + numel = next_numel; |
| 139 | + } |
| 140 | + return numel; |
| 141 | +} |
| 142 | + |
113 | 143 | #undef ET_PRI_TENSOR_SIZE |
114 | 144 | #define ET_PRI_TENSOR_SIZE PRId64 |
115 | 145 |
|
@@ -158,6 +188,7 @@ using OptionalArrayRef = |
158 | 188 | using OptionalIntArrayRef = OptionalArrayRef<int64_t>; |
159 | 189 |
|
160 | 190 | using torch::executor::compute_numel; |
| 191 | +using torch::executor::safe_numel; |
161 | 192 |
|
162 | 193 | #endif // Use ExecuTorch types |
163 | 194 |
|
|
0 commit comments