Skip to content

Commit 222711e

Browse files
authored
Add safe_numel()
Differential Revision: D102070375 Pull Request resolved: #19074
1 parent 0a43e2f commit 222711e

4 files changed

Lines changed: 75 additions & 1 deletion

File tree

runtime/core/exec_aten/exec_aten.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
#pragma once
1010

11+
#include <executorch/runtime/core/error.h> // @manual
12+
#include <executorch/runtime/core/result.h> // @manual
1113
#include <executorch/runtime/core/tensor_shape_dynamism.h> // @manual
14+
#include <executorch/runtime/platform/assert.h> // @manual
1215
#include <executorch/runtime/platform/compiler.h>
1316
#ifdef USE_ATEN_LIB
1417
#include <ATen/Tensor.h> // @manual
@@ -28,6 +31,7 @@
2831
#include <c10/util/quint2x4.h> // @manual
2932
#include <c10/util/quint4x2.h> // @manual
3033
#include <c10/util/quint8.h> // @manual
34+
#include <c10/util/safe_numerics.h> // @manual
3135
#include <c10/util/string_view.h> // @manual
3236
#include <torch/torch.h>
3337
#else // use executor
@@ -110,6 +114,32 @@ inline ssize_t compute_numel(const SizesType* sizes, ssize_t dim) {
110114
c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim)));
111115
}
112116

117+
inline ::executorch::runtime::Result<ssize_t> safe_numel(
118+
const SizesType* sizes,
119+
ssize_t dim) {
120+
ET_CHECK_OR_RETURN_ERROR(
121+
dim == 0 || sizes != nullptr,
122+
InvalidArgument,
123+
"Sizes must be provided for non-scalar tensors");
124+
ssize_t numel = 1;
125+
for (ssize_t i = 0; i < dim; i++) {
126+
ET_CHECK_OR_RETURN_ERROR(
127+
sizes[i] >= 0,
128+
InvalidArgument,
129+
"Size must be non-negative, got %zd at dimension %zd",
130+
static_cast<ssize_t>(sizes[i]),
131+
i);
132+
ssize_t next_numel;
133+
ET_CHECK_OR_RETURN_ERROR(
134+
!c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
135+
InvalidArgument,
136+
"Overflow computing numel at dimension %zd",
137+
i);
138+
numel = next_numel;
139+
}
140+
return numel;
141+
}
142+
113143
#undef ET_PRI_TENSOR_SIZE
114144
#define ET_PRI_TENSOR_SIZE PRId64
115145

@@ -158,6 +188,7 @@ using OptionalArrayRef =
158188
using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
159189

160190
using torch::executor::compute_numel;
191+
using torch::executor::safe_numel;
161192

162193
#endif // Use ExecuTorch types
163194

runtime/core/exec_aten/targets.bzl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def define_common_targets():
1616
exported_headers = ["exec_aten.h"],
1717
exported_preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
1818
visibility = ["PUBLIC"],
19-
exported_deps = ["//executorch/runtime/core:tensor_shape_dynamism"] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]),
19+
exported_deps = [
20+
"//executorch/runtime/core:core",
21+
"//executorch/runtime/core:tensor_shape_dynamism",
22+
] + ([] if aten_mode else ["//executorch/runtime/core/portable_type:portable_type"]),
2023
exported_external_deps = ["libtorch"] if aten_mode else [],
2124
)

runtime/core/portable_type/tensor_impl.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstdint>
1313

1414
#include <c10/util/irange.h>
15+
#include <c10/util/safe_numerics.h>
1516

1617
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
1718
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -43,6 +44,32 @@ ssize_t compute_numel(const TensorImpl::SizesType* sizes, ssize_t dim) {
4344
return numel;
4445
}
4546

47+
::executorch::runtime::Result<ssize_t> safe_numel(
48+
const TensorImpl::SizesType* sizes,
49+
ssize_t dim) {
50+
ET_CHECK_OR_RETURN_ERROR(
51+
dim == 0 || sizes != nullptr,
52+
InvalidArgument,
53+
"Sizes must be provided for non-scalar tensors");
54+
ssize_t numel = 1;
55+
for (const auto i : c10::irange(dim)) {
56+
ET_CHECK_OR_RETURN_ERROR(
57+
sizes[i] >= 0,
58+
InvalidArgument,
59+
"Size must be non-negative, got %zd at dimension %zd",
60+
static_cast<ssize_t>(sizes[i]),
61+
i);
62+
ssize_t next_numel;
63+
ET_CHECK_OR_RETURN_ERROR(
64+
!c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
65+
InvalidArgument,
66+
"Overflow computing numel at dimension %zd",
67+
i);
68+
numel = next_numel;
69+
}
70+
return numel;
71+
}
72+
4673
TensorImpl::TensorImpl(
4774
ScalarType type,
4875
ssize_t dim,

runtime/core/portable_type/tensor_impl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
#include <executorch/runtime/core/error.h>
1313
#include <executorch/runtime/core/portable_type/device.h>
1414
#include <executorch/runtime/core/portable_type/scalar_type.h>
15+
#include <executorch/runtime/core/result.h>
1516
#include <executorch/runtime/core/tensor_shape_dynamism.h>
17+
#include <executorch/runtime/platform/compiler.h>
1618

1719
// Forward declaration of a helper that provides access to internal resizing
1820
// methods of TensorImpl. Real definition is in
@@ -293,6 +295,16 @@ ssize_t compute_numel(
293295
const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
294296
ssize_t dim);
295297

298+
/**
299+
* Compute the number of elements based on the sizes of a tensor.
300+
* Returns Error::InvalidArgument if any intermediate multiplication would
301+
* overflow ssize_t, or if a size is negative. Prefer this over compute_numel()
302+
* for paths that can propagate an Error upward.
303+
*/
304+
::executorch::runtime::Result<ssize_t> safe_numel(
305+
const ::executorch::runtime::etensor::TensorImpl::SizesType* sizes,
306+
ssize_t dim);
307+
296308
/// Appropriate format specifier for the result of calling
297309
/// size(). Must be used instead of using zd directly to support ATen
298310
/// mode.
@@ -322,6 +334,7 @@ namespace executor {
322334
// TODO(T197294990): Remove these deprecated aliases once all users have moved
323335
// to the new `::executorch` namespaces.
324336
using ::executorch::runtime::etensor::compute_numel;
337+
using ::executorch::runtime::etensor::safe_numel;
325338
using ::executorch::runtime::etensor::TensorImpl;
326339
} // namespace executor
327340
} // namespace torch

0 commit comments

Comments
 (0)