Commit 60ffe19
portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d bilinear (#19117)
## Summary
The bilinear grid_sampler_2d portable kernel computes interpolation
weights via subtractions like `(ix_se - ix)` where both operands are
close integer-valued coordinates in pixel space. In fp16 (10 bits of
mantissa) that's classic catastrophic cancellation — the result has only
a handful of significant bits. The downstream weighted-sum accumulation
then loses further precision.
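For intuition, here is a small standalone sketch (not kernel code) that simulates fp16 rounding of the sample coordinate and shows the error the subtraction exposes; the coordinate value 47.3 and the `round_to_fp16_precision` helper are illustrative assumptions, while `ix` / `ix_se` mirror the names above.
```cpp
// Standalone illustration, not the kernel: simulate fp16 rounding by keeping
// 11 significand bits, then compare the weight (ix_se - ix) computed from the
// rounded coordinate against the full-precision one.
#include <cmath>
#include <cstdio>

// Round to the nearest value with an 11-bit significand (fp16's implicit bit
// plus 10 mantissa bits); ignores fp16 exponent-range limits.
static float round_to_fp16_precision(float x) {
  int e;
  float m = std::frexp(x, &e);            // x = m * 2^e, |m| in [0.5, 1)
  m = std::round(m * 2048.0f) / 2048.0f;  // keep 11 significand bits
  return std::ldexp(m, e);
}

int main() {
  float ix = 47.3f;                          // interior sample coordinate
  float ix_h = round_to_fp16_precision(ix);  // carries ~0.0125 absolute error
  float ix_se = std::ceil(ix);               // exact integer corner: 48

  float w = ix_se - ix;      // ~0.7000, fp32 weight
  float w_h = ix_se - ix_h;  // 0.6875, weight computed from the rounded ix
  std::printf("weight fp32=%.4f fp16-rounded=%.4f abs err=%.4f\n",
              w, w_h, std::fabs(w - w_h));  // abs err ~0.0125, ~2% relative
}
```
The subtraction itself is exact; the problem is that the small result inherits the full absolute rounding error of the large coordinate, leaving only a few significant bits.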
Measured on a unit test exercising interior grid points with fp16
inputs, the kernel drifts by ~0.1 absolute from an fp32 reference.
That's visible as incorrect depth / flow output near non-integer sample
points, which is most of them.
## Fix
An `AccType<CTYPE>` trait maps `Half` and `BFloat16` to `float` and leaves every other dtype unchanged. It is used for the intermediate coordinates, the weight computation, and the `out_val` accumulation. Loads cast `CTYPE ->
ACC`; the final store casts `ACC -> CTYPE` once. Only internal math is promoted; memory layout, public API, and tensor dtypes are unchanged.
```cpp
template <typename CTYPE>
using AccType = std::conditional_t<
    std::is_same_v<CTYPE, executorch::aten::Half> ||
        std::is_same_v<CTYPE, executorch::aten::BFloat16>,
    float,
    CTYPE>;
```
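Below is a hedged sketch, not the kernel's actual code, of how the trait threads through the bilinear path: `bilinear_sample`, its pointer parameters, and the corner naming are illustrative, while `ix`, `ix_se`, and `out_val` follow the names used in this description.
```cpp
// Illustrative only; the real kernel's variable names and indexing differ.
// Coordinates and weights live in ACC, loads promote CTYPE -> ACC, and the
// accumulated value is narrowed back to CTYPE exactly once at the store.
#include <cmath>

template <typename CTYPE>
CTYPE bilinear_sample(const CTYPE* nw_p, const CTYPE* ne_p,
                      const CTYPE* sw_p, const CTYPE* se_p,
                      AccType<CTYPE> ix, AccType<CTYPE> iy) {
  using ACC = AccType<CTYPE>;
  const ACC ix_nw = std::floor(ix), iy_nw = std::floor(iy);
  const ACC ix_se = ix_nw + 1,      iy_se = iy_nw + 1;
  // Interpolation weights; (ix_se - ix) etc. now carry fp32 precision.
  const ACC nw = (ix_se - ix) * (iy_se - iy);
  const ACC ne = (ix - ix_nw) * (iy_se - iy);
  const ACC sw = (ix_se - ix) * (iy - iy_nw);
  const ACC se = (ix - ix_nw) * (iy - iy_nw);
  ACC out_val = static_cast<ACC>(*nw_p) * nw + static_cast<ACC>(*ne_p) * ne +
                static_cast<ACC>(*sw_p) * sw + static_cast<ACC>(*se_p) * se;
  return static_cast<CTYPE>(out_val);  // single ACC -> CTYPE narrowing cast
}
```
The only narrowing conversion is the final `static_cast<CTYPE>`, which is what keeps memory layout and tensor dtypes untouched.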
## Effects
- **fp32 / Int / any non-half dtype**: `AccType<T>` is `T`, so the
generated code is byte-identical. No behavior change.
- **Half / BFloat16**: `max_abs` vs an fp32 reference drops from **~0.1
to 0** on the shapes I tested (N=1..2, C=7..64, H/W up to 96, both
`align_corners` values).
- **Perf**: a handful of fp16↔fp32 conversions per output element. Not
measurable at op level; well within the portable kernel's scalar cost
envelope.
## Scope
Only touches the bilinear interpolation path. The nearest-mode path
doesn't do weighted-sum accumulation and doesn't have the cancellation
issue — left alone in this change.
## Test plan
- [x] Builds clean for Android arm64 and host (Apple Clang 21).
- [x] Verified numerically via a standalone harness that runs the kernel
with matched fp32 / fp16 inputs and compares against an
fp32-then-downcast reference. All shapes pass within a single fp16 ULP
(or are bit-exact); a sketch of that tolerance check follows this list.
fp32 tests remain bit-identical to the pre-change kernel.
- [x] Existing `kernels/test/op_grid_sampler_2d_test.cpp` unit tests
continue to pass (both fp32 shapes that were previously tested, and the
fp16 path I'm specifically fixing).
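For reference, a minimal sketch of the kind of one-fp16-ULP tolerance check described above, assuming a helper that estimates fp16 spacing around the reference value; the names are illustrative and this is not the harness code.
```cpp
// Illustrative tolerance check, not the actual harness: accept `actual` if it
// is within one fp16 spacing (ULP) of the fp32-then-downcast reference.
// Handles normal values only; values near zero would need a subnormal floor.
#include <cmath>

inline float fp16_ulp(float ref) {
  int e;
  std::frexp(ref, &e);              // |ref| in [2^(e-1), 2^e)
  return std::ldexp(1.0f, e - 11);  // fp16 value spacing at that magnitude
}

inline bool within_one_fp16_ulp(float actual, float ref) {
  return std::fabs(actual - ref) <= fp16_ulp(ref);
}
```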
Happy to add an fp16-specific test case to `op_grid_sampler_2d_test.cpp`
if useful for CI coverage here — just let me know the preferred
approach.
cc @larryliu0820 @manuelcandales
1 file changed: 65 additions & 34 deletions