# [Feature Request] Add Snake activation functor for Epilogue Visitor Tree (EVT)

**Which component requires the feature?**
CUTLASS C++
**Is your feature request related to a problem? Please describe.**
CUTLASS provides several built-in activation functors (GELU, ReLU, SiLU, etc.) that can be used inside the Epilogue Visitor Tree framework. However, there is currently no support for the Snake activation function (x + sin²(αx) / α, Ziyin et al. 2020), a parametric activation that has seen adoption in audio synthesis models (neural vocoders) and other domains where periodic inductive bias is beneficial.
Users who need Snake activation in fused GEMM epilogues currently have to fall back to unfused workflows, running the activation as a separate kernel after GEMM, which incurs additional global memory round-trips.
**Describe the solution you'd like**
Add a `SnakeOp` functor to CUTLASS's activation function library, following the same pattern as existing functors like `GELU`. The implementation would:

- Support scalar and `Array<T, N>` specializations
- Take a learnable per-channel α parameter via `Sm90RowBroadcast`
- Be composable within the SM90 EVT framework (e.g. `Sm90Compute<SnakeOp, ...>`)
I have a working implementation targeting SM90 (Hopper) that I'd be happy to contribute as a PR. The EVT tree structure is:
```cpp
// EVT tree: SnakeOp(AccFetch, RowBroadcast(alpha))
using SnakeEpilogue = Sm90EVT<
    Sm90Compute<SnakeOp, ElementOut, ElementEpi,
                cutlass::FloatRoundStyle::round_to_nearest>,
    Sm90AccFetch,
    Sm90RowBroadcast<0, TileShape, float>
>;
```
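For reference, the scalar math such a functor would compute can be sketched in plain host-side C++ (this is an illustrative sketch, not the contributed code; the `SnakeOp` name and the small-α guard threshold are assumptions, and a device implementation would use the `__sinf` fast-math intrinsic instead of `std::sin`):

```cpp
#include <cassert>
#include <cmath>

// Illustrative scalar sketch of the Snake activation,
// snake(x; alpha) = x + sin^2(alpha * x) / alpha,
// written as a callable functor in the style of CUTLASS's
// activation functors. As alpha -> 0 the function tends to the
// identity, so small alpha is guarded (threshold is illustrative).
struct SnakeOp {
  float operator()(float x, float alpha) const {
    if (std::fabs(alpha) < 1e-8f) {
      return x;  // limit of x + sin^2(alpha*x)/alpha as alpha -> 0
    }
    float s = std::sin(alpha * x);
    return x + s * s / alpha;
  }
};
```

The vectorized `Array<T, N>` specialization would simply apply this elementwise per fragment, as the existing functors do.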
**Describe alternatives you've considered**
- **Unfused approach:** Running Snake activation as a separate PyTorch/custom CUDA kernel after the GEMM. This works but requires an extra global memory read and write of the full output tensor.
- **Approximating with existing functors:** Snake cannot be accurately decomposed into existing CUTLASS activation primitives, since it requires `sin²` and a per-channel learnable parameter.
**Additional context**

**Benchmarks (H100 SXM5 80GB, SM90)**
Implicit GEMM shapes from a production neural vocoder (Conv1d k=7 mapped to GEMM as M=T, N=C_out, K=C_in×7):
| GEMM Shape (M × N × K) | cuDNN Conv + Separate Snake | CUTLASS Fused EVT | Speedup |
|---|---|---|---|
| 1200 × 768 × 5376 | 0.108 – 0.138 ms | 0.078 ms | 1.39 – 1.76× |
| 6000 × 384 × 2688 | 0.125 – 0.664 ms | 0.078 ms | 1.61 – 8.51× |
| 24000 × 192 × 1344 | 0.174 – 0.193 ms | 0.082 ms | 2.12 – 2.35× |
| 48000 × 96 × 672 | 0.168 – 0.170 ms | 0.063 – 0.064 ms | 2.65 – 2.70× |
Median speedup across all 12 shapes is approximately 2.1×, with the benefit primarily coming from eliminating the extra memory round-trip for the activation kernel.
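The shape mapping and the traffic argument above can be sketched in a few lines of C++ (the function names are illustrative; the traffic estimate counts only the one extra read and one extra write of the M × N output that an unfused activation pass performs, ignoring launch overhead):

```cpp
#include <cassert>
#include <cstdint>

// Conv1d -> implicit GEMM shape mapping used for the benchmark
// shapes above: M = T (time steps), N = C_out, K = C_in * kernel.
struct GemmShape { int64_t m, n, k; };

GemmShape conv1d_to_gemm(int64_t T, int64_t c_in, int64_t c_out,
                         int64_t kernel) {
  return {T, c_out, c_in * kernel};
}

// Extra global-memory traffic of an unfused activation pass:
// one read plus one write of the full M x N output tensor.
int64_t unfused_activation_bytes(GemmShape s, int64_t bytes_per_elem) {
  return 2 * s.m * s.n * bytes_per_elem;
}
```

For the first benchmark row (T = 1200, C_in = 768, C_out = 768, k = 7), this mapping reproduces the 1200 × 768 × 5376 shape, and the unfused pass moves an extra ~3.7 MB in fp16.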
**Implementation**
The implementation follows CUTLASS's existing functor conventions (scalar and `Array<T, N>` specializations, the `__sinf` fast-math intrinsic). I'm ready to open a PR if there's interest.
**Reference**

Ziyin, L., Hartwig, T., & Ueda, M. "Neural Networks Fail to Learn Periodic Functions and How to Fix It." NeurIPS 2020. [arXiv:2006.08195](https://arxiv.org/abs/2006.08195). Introduces x + (1/α) sin²(αx).