portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d and sum#6
Closed
portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d and sum#6
Conversation
… sum
Both kernels previously performed all interior arithmetic in the input
tensor's dtype. For half-precision inputs that's a material precision bug:
* grid_sampler_2d.out (bilinear): interpolation weights are derived from
subtractions like `(ix_se - ix)` and `(iy_se - iy)` where both operands
are close integer values. In fp16 that's catastrophic cancellation —
the result has only a handful of significant bits. The weighted sum
then further accumulates error in fp16.
* sum.IntList_out: the fast path (innermost contiguous dim) uses a scalar
accumulator of input dtype: `CTYPE acc = 0; acc += row[j]`. Over a
reduction of >100 fp16 values, cumulative error reaches the same order
of magnitude as the reduction itself. Observed empirically on a real
model: sums of 256 fp16 values drift by up to ~0.14 absolute relative
to the fp32 reference. The slow path (MapReduceOverDimListPlan) had
the same issue via its CTYPE_OUT template parameter.
Fix in both files: introduce a simple `AccType<CTYPE>` / `SumAccType<CTYPE>`
trait that maps Half and BFloat16 to `float` and leaves all other types
unchanged. Use it for the internal accumulator, intermediate coordinate /
weight computation, and the single cast back to output dtype at store
time. Loads and stores remain in the tensor's dtype — only the inner
arithmetic is promoted.
Effects:
* fp32 / Int / other dtypes: byte-identical output (AccType is a no-op).
* fp16 / BFloat16: substantially tighter agreement with fp32 reference.
On a unit test exercising the actually-hot shapes in our depth model,
`max_abs` between "fp16 input → fp16 output" and "run in fp32 and cast
at the end" drops from ~0.14 to 7.6e-6 for sum and from ~0.1 to 0
for grid_sampler_2d bilinear.
* Perf: the overhead is a handful of fp16↔fp32 conversions per output
element. Not measurable at the op level — still well within the
scalar-portable-kernel cost envelope for Half inputs.
No public API change. No behavioral change for fp32 workloads.
Collaborator
Author
|
Superseded — the grid_sampler portion is now upstream at pytorch#19117, and the sum portion is dropped (not needed for the polycam use case once the NEON sum kernel is in place). |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Both
grid_sampler_2d.out(bilinear) andsum.IntList_outin the portable kernels previously did all interior arithmetic in the input tensor's dtype. For fp16 inputs that's a material precision problem, not just FP rounding noise:grid_sampler_2d.out(bilinear)Interpolation weights are derived from subtractions like `(ix_se - ix)` and `(iy_se - iy)` where both operands are close integer values. In fp16, subtracting two close values of the form `N.xxx` and `N+1.xxx` with only ~10 bits of mantissa destroys most of the precision — classic catastrophic cancellation. The downstream weighted sum then accumulates further error in fp16.
sum.IntList_outThe fast path (innermost contiguous dim):
```cpp
CTYPE acc = 0;
for (int64_t j = 0; j < reduce_size; j++) {
acc += row[j]; // fp16 += fp16 over many elements
}
```
Over a reduction of more than ~100 fp16 values of similar magnitude, cumulative error grows to the same order as the reduction result itself. The slow path (via `MapReduceOverDimListPlan`) had the same issue because its `CTYPE_OUT` template parameter was used for the accumulator.
Fix
In each file: an `AccType` / `SumAccType` trait that maps `Half` and `BFloat16` to `float` and leaves every other dtype unchanged. The trait is used for the internal accumulator, intermediate coordinate / weight computation, and the single cast back to the output dtype at store time. Loads and stores remain in the tensor's dtype — only the inner arithmetic is promoted.
```cpp
template
using AccType = std::conditional_t<
std::is_same_v<CTYPE, executorch::aten::Half> ||
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
float,
CTYPE>;
```
Measured effects
On a unit test comparing "fp16 input → fp16 output" against "run in fp32 and cast at the end" for the shapes actually exercised by the polycam depth model:
Non-effects
Test plan
Candidate for upstream at some point — this is a general correctness improvement, not Polycam-specific.
🤖 Generated with Claude Code