
[ET-VK] Add apply_rotary_emb_interleaved fused operator #19115

Open
SS-JIA wants to merge 3 commits into gh/SS-JIA/524/base from gh/SS-JIA/524/head

Conversation

Contributor

@SS-JIA SS-JIA commented Apr 24, 2026

Stack from ghstack (oldest at bottom):

Introduces `et_vk.apply_rotary_emb_interleaved`, a fused Vulkan custom operator for the "complex-number" RoPE variant used by SAM2/EdgeTAM's memory attention. This replaces a 12+-op layout-shuffle chain (`view/unbind/stack/view` -> lowers to `slice_copy + squeeze_copy + unsqueeze_copy + cat + view_copy`) with a single GPU dispatch.

**Math**: On pair-interleaved inputs where element `2k` is real and `2k+1` is imag, for each `k in [0, C/2)`:

  out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
  out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
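
For reference, a minimal eager-mode sketch of this math (not the actual CPU kernel; the reshape of `freqs_cis` below is an assumption based on the op contract):

```python
import torch

def rotary_emb_interleaved_ref(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [B, N, C] with real/imag pairs interleaved on the last dim.
    # freqs_cis: any rank with N * C elements, cos/sin interleaved innermost.
    B, N, C = x.shape
    fc = freqs_cis.reshape(N, C // 2, 2)
    cos, sin = fc[..., 0], fc[..., 1]        # each [N, C/2], broadcasts over B
    xr, xi = x[..., 0::2], x[..., 1::2]      # real / imag parts, each [B, N, C/2]
    out = torch.empty_like(x)
    out[..., 0::2] = xr * cos - xi * sin     # out[2k]   = x[2k]*cos[k] - x[2k+1]*sin[k]
    out[..., 1::2] = xr * sin + xi * cos     # out[2k+1] = x[2k]*sin[k] + x[2k+1]*cos[k]
    return out
```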

**Why a new op instead of reusing `et_vk.apply_rotary_emb`**: The existing LLM-oriented operator takes `(xq, xk)` pairs with separate `freqs_cos` / `freqs_sin` tensors and 4D `(B, S, H, D)` shapes optimized for LLM prefill two-texel-per-thread reuse. SAM2's memory attention passes a single 3D `(B, N, C)` tensor through RoPE (no heads dim) with a fused `[N, C/2, 2]` freqs tensor. Reusing the existing op would force runtime splits of the fused freqs and double-dispatch Q/K separately, defeating the fuse. A sibling shader is tighter for both workloads.

**Op contract**: `apply_rotary_emb_interleaved(x, freqs_cis) -> Tensor` where `x` is `[B, N, C]` and `freqs_cis` is any rank with `N*C` elements and the `cos`/`sin` values interleaved on the innermost dim. In EdgeTAM's memory attention the native shape is `[1, N, C/2, 2]`; passing it through without a reshape keeps the exported graph clean of bracketing view_copy dispatches.
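
A hypothetical call site with EdgeTAM-like shapes (the sizes are illustrative, and the `torch.ops.et_vk` dispatch path assumes a registration along the lines described below):

```python
import torch

x = torch.randn(1, 4096, 256)             # [B, N, C]
freqs_cis = torch.randn(1, 4096, 128, 2)  # native [1, N, C/2, 2]; N*C elements total
out = torch.ops.et_vk.apply_rotary_emb_interleaved(x, freqs_cis)
assert out.shape == x.shape               # output keeps the shape of x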

**Shader**: Single-dispatch kernel, one texel out per thread. Each thread reads one `x` texel (2 real/imag pairs) and the corresponding `freqs_cis` entries (2 cos/sin pairs) flat-indexed from buffer storage, writes one output texel. `x` and output support buffer + texture3d; `freqs_cis` is always buffer-storage (small tensor, flat indexing is simplest). Supports fp16 and fp32 via the `FP_T` dtype iterator in the YAML.
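
Loosely, one shader invocation computes the following (modeled in Python; the actual texel packing and freqs indexing in the GLSL kernel may differ, this only illustrates the 4-in/4-out structure per thread):

```python
def one_texel(x_row, freqs_row, t):
    # x_row: flat channel values for one (b, n) position; freqs_row: flat
    # cos/sin values for that position; t: output texel index (4 channels).
    out = [0.0] * 4
    for p in range(2):                    # two real/imag pairs per texel
        k = 2 * t + p                     # frequency index for this pair
        c, s = freqs_row[2 * k], freqs_row[2 * k + 1]
        re, im = x_row[4 * t + 2 * p], x_row[4 * t + 2 * p + 1]
        out[2 * p] = re * c - im * s
        out[2 * p + 1] = re * s + im * c
    return out
```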

**Op registration**: `Meta` kernel returns `torch.empty_like(x)` to keep the op opaque during `torch.export`. `CPU` kernel holds the reference math so non-Vulkan backends keep working. `op_registry.py` pins `freqs_cis` storage to `CONTIGUOUS_BUFFER` while leaving `x` at `CONTIGUOUS_ANY`.
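
A rough sketch of what such a registration could look like via `torch.library` (ExecuTorch's actual registration plumbing may be structured differently; the schema string and the reuse of the reference helper above are assumptions):

```python
import torch
from torch.library import Library, impl

et_vk_lib = Library("et_vk", "FRAGMENT")  # assumes the et_vk namespace already exists
et_vk_lib.define("apply_rotary_emb_interleaved(Tensor x, Tensor freqs_cis) -> Tensor")

@impl(et_vk_lib, "apply_rotary_emb_interleaved", "Meta")
def _meta(x, freqs_cis):
    # Shape/dtype propagation only; keeps the op opaque during torch.export.
    return torch.empty_like(x)

@impl(et_vk_lib, "apply_rotary_emb_interleaved", "CPU")
def _cpu(x, freqs_cis):
    # Reference math (see the eager-mode sketch above) so non-Vulkan
    # backends keep working.
    return rotary_emb_interleaved_ref(x, freqs_cis)
```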

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)


pytorch-bot Bot commented Apr 24, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19115

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 3 Cancelled Jobs, 2 Unrelated Failures

As of commit 9ad2459 with merge base eef7921:

NEW FAILURE - The following job has failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Apr 24, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

SS-JIA pushed a commit that referenced this pull request Apr 24, 2026
Pull Request resolved: #19115

ghstack-source-id: 372548626
@exported-using-ghexport

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)
SS-JIA pushed a commit that referenced this pull request Apr 25, 2026
Pull Request resolved: #19115

ghstack-source-id: 372777610
@exported-using-ghexport

Differential Revision: [D102360202](https://our.internmc.facebook.com/intern/diff/D102360202/)

Labels

CLA Signed, fb-exported, meta-exported
