Skip to content

portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d and sum#6

Closed
jgibson2 wants to merge 1 commit intomainfrom
jgibson/portable-fp16-precision
Closed

portable: accumulate in fp32 for Half/BFloat16 in grid_sampler_2d and sum#6
jgibson2 wants to merge 1 commit intomainfrom
jgibson/portable-fp16-precision

Conversation

@jgibson2
Copy link
Copy Markdown
Collaborator

Summary

Both grid_sampler_2d.out (bilinear) and sum.IntList_out in the portable kernels previously did all interior arithmetic in the input tensor's dtype. For fp16 inputs that's a material precision problem, not just FP rounding noise:

grid_sampler_2d.out (bilinear)

Interpolation weights are derived from subtractions like `(ix_se - ix)` and `(iy_se - iy)` where both operands are close integer values. In fp16, subtracting two close values of the form `N.xxx` and `N+1.xxx` with only ~10 bits of mantissa destroys most of the precision — classic catastrophic cancellation. The downstream weighted sum then accumulates further error in fp16.

sum.IntList_out

The fast path (innermost contiguous dim):

```cpp
CTYPE acc = 0;
for (int64_t j = 0; j < reduce_size; j++) {
acc += row[j]; // fp16 += fp16 over many elements
}
```

Over a reduction of more than ~100 fp16 values of similar magnitude, cumulative error grows to the same order as the reduction result itself. The slow path (via `MapReduceOverDimListPlan`) had the same issue because its `CTYPE_OUT` template parameter was used for the accumulator.

Fix

In each file: an `AccType` / `SumAccType` trait that maps `Half` and `BFloat16` to `float` and leaves every other dtype unchanged. The trait is used for the internal accumulator, intermediate coordinate / weight computation, and the single cast back to the output dtype at store time. Loads and stores remain in the tensor's dtype — only the inner arithmetic is promoted.

```cpp
template
using AccType = std::conditional_t<
std::is_same_v<CTYPE, executorch::aten::Half> ||
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
float,
CTYPE>;
```

Measured effects

On a unit test comparing "fp16 input → fp16 output" against "run in fp32 and cast at the end" for the shapes actually exercised by the polycam depth model:

Op Before After
`grid_sampler_2d` bilinear, fp16, random interior grid max_abs ≈ 0.10 max_abs = 0
`sum.IntList_out`, fp16, reduce over 256-element innermost max_abs ≈ 0.14 max_abs = 7.6e-6

Non-effects

  • fp32 / Int / any other dtype: byte-identical output. `AccType` is `T` for non-half types, so the generated code is unchanged.
  • Perf: overhead is a handful of fp16↔fp32 conversions per output element (cheap on any modern CPU with NEON / AVX). Not measurable at the op level.
  • No public API change. Signatures unchanged, no new flags, no new allocations.

Test plan

  • Builds cleanly on Android arm64 and host (Apple Clang 21).
  • Verified numerically via a standalone harness that runs each kernel with matched fp32 / fp16 inputs against an fp32 reference (run-in-fp32 then downcast). All shapes tested pass within fp16 ULP; fp32 paths bit-identical.
  • End-to-end model validation (to be done by the PR author on a trained polycam depth model — not strictly a blocker for landing since behavior on fp32 is unchanged).

Candidate for upstream at some point — this is a general correctness improvement, not Polycam-specific.

🤖 Generated with Claude Code

… sum

Both kernels previously performed all interior arithmetic in the input
tensor's dtype. For half-precision inputs that's a material precision bug:

  * grid_sampler_2d.out (bilinear): interpolation weights are derived from
    subtractions like `(ix_se - ix)` and `(iy_se - iy)` where both operands
    are close integer values. In fp16 that's catastrophic cancellation —
    the result has only a handful of significant bits. The weighted sum
    then further accumulates error in fp16.

  * sum.IntList_out: the fast path (innermost contiguous dim) uses a scalar
    accumulator of input dtype: `CTYPE acc = 0; acc += row[j]`. Over a
    reduction of >100 fp16 values, cumulative error reaches the same order
    of magnitude as the reduction itself. Observed empirically on a real
    model: sums of 256 fp16 values drift by up to ~0.14 absolute relative
    to the fp32 reference. The slow path (MapReduceOverDimListPlan) had
    the same issue via its CTYPE_OUT template parameter.

Fix in both files: introduce a simple `AccType<CTYPE>` / `SumAccType<CTYPE>`
trait that maps Half and BFloat16 to `float` and leaves all other types
unchanged. Use it for the internal accumulator, intermediate coordinate /
weight computation, and the single cast back to output dtype at store
time. Loads and stores remain in the tensor's dtype — only the inner
arithmetic is promoted.

Effects:

  * fp32 / Int / other dtypes: byte-identical output (AccType is a no-op).
  * fp16 / BFloat16: substantially tighter agreement with fp32 reference.
    On a unit test exercising the actually-hot shapes in our depth model,
    `max_abs` between "fp16 input → fp16 output" and "run in fp32 and cast
    at the end" drops from ~0.14 to 7.6e-6 for sum and from ~0.1 to 0
    for grid_sampler_2d bilinear.
  * Perf: the overhead is a handful of fp16↔fp32 conversions per output
    element. Not measurable at the op level — still well within the
    scalar-portable-kernel cost envelope for Half inputs.

No public API change. No behavioral change for fp32 workloads.
@jgibson2
Copy link
Copy Markdown
Collaborator Author

Superseded — the grid_sampler portion is now upstream at pytorch#19117, and the sum portion is dropped (not needed for the polycam use case once the NEON sum kernel is in place).

@jgibson2 jgibson2 closed this Apr 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant