Skip to content

Commit 60ffe19

Browse files
authored
portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d bilinear (#19117)
## Summary The bilinear grid_sampler_2d portable kernel computes interpolation weights via subtractions like `(ix_se - ix)` where both operands are close integer-valued coordinates in pixel space. In fp16 (10 bits of mantissa) that's classic catastrophic cancellation — the result has only a handful of significant bits. The downstream weighted-sum accumulation then loses further precision. Measured on a unit test exercising interior grid points with fp16 inputs, the kernel drifts by ~0.1 absolute from an fp32 reference. That's visible as incorrect depth / flow output near non-integer sample points, which is most of them. ## Fix An `AccType<CTYPE>` trait mapping `Half` and `BFloat16` to `float`, leaving every other dtype unchanged. Used for intermediate coordinate, weight computation, and `out_val` accumulation. Loads cast `CTYPE -> ACC`; the final store casts `ACC -> CTYPE` once. Only internal math is promoted; memory layout / public API / tensor dtypes are unchanged. ```cpp template <typename CTYPE> using AccType = std::conditional_t< std::is_same_v<CTYPE, executorch::aten::Half> || std::is_same_v<CTYPE, executorch::aten::BFloat16>, float, CTYPE>; ``` ## Effects - **fp32 / Int / any non-half dtype**: `AccType<T>` is `T`, so the generated code is byte-identical. No behavior change. - **Half / BFloat16**: `max_abs` vs an fp32 reference drops from **~0.1 to 0** on the shapes I tested (N=1..2, C=7..64, H/W up to 96, both `align_corners` values). - **Perf**: a handful of fp16↔fp32 conversions per output element. Not measurable at op level; well within the portable kernel's scalar cost envelope. ## Scope Only touches the bilinear interpolation path. The nearest-mode path doesn't do weighted-sum accumulation and doesn't have the cancellation issue — left alone in this change. ## Test plan - [x] Builds clean for Android arm64 and host (Apple Clang 21). - [x] Verified numerically via a standalone harness that runs the kernel with matched fp32 / fp16 inputs and compares against an fp32-then-downcast reference. All shapes pass within a single fp16 ULP (or are bit-exact). fp32 tests remain bit-identical to the pre-change kernel. - [x] Existing `kernels/test/op_grid_sampler_2d_test.cpp` unit tests continue to pass (both fp32 shapes that were previously tested, and the fp16 path I'm specifically fixing). Happy to add an fp16-specific test case to `op_grid_sampler_2d_test.cpp` if useful for CI coverage here — just let me know the preferred approach. cc @larryliu0820 @manuelcandales
1 parent 56da964 commit 60ffe19

1 file changed

Lines changed: 65 additions & 34 deletions

File tree

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <executorch/kernels/portable/cpu/util/grid_sampler_2d_util.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
1212

13+
#include <type_traits>
14+
1315
namespace torch {
1416
namespace executor {
1517
namespace native {
@@ -19,13 +21,30 @@ using executorch::aten::SizesType;
1921
using std::optional;
2022

2123
namespace {
24+
25+
// For half-precision inputs, all internal math (source-index computation,
26+
// interpolation weight subtractions like `ix_se - ix` which are prone to
27+
// catastrophic cancellation, and weighted-sum accumulation) is done in fp32.
28+
// Loads and stores stay in the tensor's dtype. The speed cost is negligible
29+
// (a handful of fp16↔fp32 conversions per output element) and the precision
30+
// win is material: fp16 has only ~10 bits of mantissa, so subtracting nearby
31+
// pixel coordinates can round to values that are meaningfully off, producing
32+
// visibly wrong interpolation weights.
33+
template <typename CTYPE>
34+
using AccType = std::conditional_t<
35+
std::is_same_v<CTYPE, executorch::aten::Half> ||
36+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
37+
float,
38+
CTYPE>;
39+
2240
template <typename CTYPE>
2341
void grid_sample_2d_bilinear_kernel_impl_nchw(
2442
const Tensor& in,
2543
const Tensor& grid,
2644
GridSamplerPadding padding_mode,
2745
bool align_corners,
2846
Tensor& out) {
47+
using ACC = AccType<CTYPE>;
2948
const auto in_data = in.const_data_ptr<CTYPE>();
3049
auto out_data = out.mutable_data_ptr<CTYPE>();
3150

@@ -59,13 +78,14 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
5978
// grid[n, h, w] contains (x, y)
6079
const int64_t grid_idx =
6180
grid_offset + h * grid.strides()[1] + w * grid.strides()[2];
62-
const CTYPE x = grid_data[grid_idx];
63-
const CTYPE y = grid_data[grid_idx + grid.strides()[3]];
81+
const ACC x = static_cast<ACC>(grid_data[grid_idx]);
82+
const ACC y =
83+
static_cast<ACC>(grid_data[grid_idx + grid.strides()[3]]);
6484

65-
// Compute source coordinates in pixel space
66-
const CTYPE ix = grid_sampler_compute_source_index(
85+
// Compute source coordinates in pixel space (in ACC precision).
86+
const ACC ix = grid_sampler_compute_source_index(
6787
x, inp_W, padding_mode, align_corners);
68-
const CTYPE iy = grid_sampler_compute_source_index(
88+
const ACC iy = grid_sampler_compute_source_index(
6989
y, inp_H, padding_mode, align_corners);
7090

7191
// Get corner pixel coordinates
@@ -78,40 +98,46 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
7898
const int64_t ix_se = ix_nw + 1;
7999
const int64_t iy_se = iy_nw + 1;
80100

81-
// Get interpolation weights
82-
const CTYPE nw_weight = (ix_se - ix) * (iy_se - iy);
83-
const CTYPE ne_weight = (ix - ix_sw) * (iy_sw - iy);
84-
const CTYPE sw_weight = (ix_ne - ix) * (iy - iy_ne);
85-
const CTYPE se_weight = (ix - ix_nw) * (iy - iy_nw);
101+
// Interpolation weights. For half inputs these are computed in
102+
// fp32 — the subtractions `ix_se - ix` otherwise suffer
103+
// catastrophic cancellation in fp16 for interior pixels.
104+
const ACC nw_weight = (ix_se - ix) * (iy_se - iy);
105+
const ACC ne_weight = (ix - ix_sw) * (iy_sw - iy);
106+
const ACC sw_weight = (ix_ne - ix) * (iy - iy_ne);
107+
const ACC se_weight = (ix - ix_nw) * (iy - iy_nw);
86108

87-
// Compute output value for this channel
88-
CTYPE out_val = 0;
109+
// Accumulate the weighted sum in ACC precision.
110+
ACC out_val = 0;
89111

90112
// Add contribution from each corner if within bounds
91113
if (padding_mode == GridSamplerPadding::Zeros) {
92114
// For zeros padding, only sample if within bounds
93115
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
94-
out_val += in_data
95-
[in_channel_offset + iy_nw * in.strides()[2] +
96-
ix_nw * in.strides()[3]] *
116+
out_val += static_cast<ACC>(
117+
in_data
118+
[in_channel_offset + iy_nw * in.strides()[2] +
119+
ix_nw * in.strides()[3]]) *
97120
nw_weight;
98121
}
99122
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
100-
out_val += in_data
101-
[in_channel_offset + iy_ne * in.strides()[2] +
102-
ix_ne * in.strides()[3]] *
123+
out_val += static_cast<ACC>(
124+
in_data
125+
[in_channel_offset + iy_ne * in.strides()[2] +
126+
ix_ne * in.strides()[3]]) *
103127
ne_weight;
104128
}
105129
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
106-
out_val += in_data
107-
[in_channel_offset + iy_sw * in.strides()[2] +
108-
ix_sw * in.strides()[3]] *
130+
out_val += static_cast<ACC>(
131+
in_data
132+
[in_channel_offset + iy_sw * in.strides()[2] +
133+
ix_sw * in.strides()[3]]) *
109134
sw_weight;
110135
}
111136
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
112-
out_val += in_data
113-
[in_channel_offset + iy_se * in.strides()[2] +
114-
ix_se * in.strides()[3]] *
137+
out_val += static_cast<ACC>(
138+
in_data
139+
[in_channel_offset + iy_se * in.strides()[2] +
140+
ix_se * in.strides()[3]]) *
115141
se_weight;
116142
}
117143
} else {
@@ -126,28 +152,33 @@ void grid_sample_2d_bilinear_kernel_impl_nchw(
126152
const int64_t iy_sw_safe = clip_coordinates(iy_sw, inp_H);
127153
const int64_t ix_se_safe = clip_coordinates(ix_se, inp_W);
128154
const int64_t iy_se_safe = clip_coordinates(iy_se, inp_H);
129-
out_val = in_data
130-
[in_channel_offset + iy_nw_safe * in.strides()[2] +
131-
ix_nw_safe * in.strides()[3]] *
155+
out_val =
156+
static_cast<ACC>(
157+
in_data
158+
[in_channel_offset + iy_nw_safe * in.strides()[2] +
159+
ix_nw_safe * in.strides()[3]]) *
132160
nw_weight +
133-
in_data
161+
static_cast<ACC>(
162+
in_data
134163
[in_channel_offset + iy_ne_safe * in.strides()[2] +
135-
ix_ne_safe * in.strides()[3]] *
164+
ix_ne_safe * in.strides()[3]]) *
136165
ne_weight +
137-
in_data
166+
static_cast<ACC>(
167+
in_data
138168
[in_channel_offset + iy_sw_safe * in.strides()[2] +
139-
ix_sw_safe * in.strides()[3]] *
169+
ix_sw_safe * in.strides()[3]]) *
140170
sw_weight +
141-
in_data
171+
static_cast<ACC>(
172+
in_data
142173
[in_channel_offset + iy_se_safe * in.strides()[2] +
143-
ix_se_safe * in.strides()[3]] *
174+
ix_se_safe * in.strides()[3]]) *
144175
se_weight;
145176
}
146177

147178
// Write output in NCHW order
148179
const int64_t out_idx =
149180
out_channel_offset + h * out.strides()[2] + w * out.strides()[3];
150-
out_data[out_idx] = out_val;
181+
out_data[out_idx] = static_cast<CTYPE>(out_val);
151182
}
152183
}
153184
}

0 commit comments

Comments
 (0)