Commit 9ad1392

Andrew Grebenisan authored and meta-codesync[bot] committed
Merge back to back slices on the same dim (pytorch#19120)
Summary:
Pull Request resolved: pytorch#19120

If we have back-to-back slices on the same dimension, we can remove the top slice and just perform the second one. RemoveOrReplacePassInterface will handle a whole cascade if it exists.

Reviewed By: abeakkas

Differential Revision: D102278253
1 parent e87499f commit 9ad1392
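
The rewrite rests on a simple identity for step=1 slices: two chained slices on the same dimension collapse into one slice with merged bounds, new_start = parent_start + child_start and new_end = min(parent_start + child_end, parent_end). A minimal sketch in plain PyTorch, independent of the pass itself and using made-up bounds, that checks this identity:

import torch

x = torch.randn(2, 3, 4, 80)
p_start, p_end = 10, 50  # parent slice bounds on the last dim (hypothetical values)
c_start, c_end = 5, 45   # child slice bounds applied to the parent's output

chained = x[..., p_start:p_end][..., c_start:c_end]
merged = x[..., p_start + c_start : min(p_start + c_end, p_end)]
assert torch.equal(chained, merged)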

3 files changed

Lines changed: 329 additions & 1 deletion

File tree

backends/cadence/aot/fuse_ops.py
backends/cadence/aot/simplify_ops.py
backends/cadence/aot/tests/test_fusion_ops_passes.py

backends/cadence/aot/fuse_ops.py

Lines changed: 71 additions & 0 deletions
@@ -31,6 +31,7 @@
     HierarchicalInplacePassInterface,
     register_cadence_pass,
     RemoveOrReplacePassInterface,
+    set_arg,
 )
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
@@ -1105,6 +1106,75 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         return True


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class FuseSliceSameDimPass(RemoveOrReplacePassInterface):
+    """Fuse chained slices on the same dim into a single slice.
+
+    When a slice_copy's input is another slice_copy on the same dimension
+    with step=1, the child slice can read directly from the grandparent
+    with merged indices, eliminating the intermediate slice.
+
+    Handles negative start/end indices by canonicalizing them against the
+    relevant dimension size before merging.
+    """
+
+    @staticmethod
+    def _canonicalize(val: int, dim_size: int) -> int:
+        return val + dim_size if val < 0 else val
+
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.slice_copy.Tensor]
+
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        parent = get_arg(node, "input", torch.fx.Node)
+        if parent.target != exir_ops.edge.aten.slice_copy.Tensor:
+            return False
+
+        grandparent = get_arg(parent, "input", torch.fx.Node)
+        ndim = len(grandparent.meta["val"].shape)
+        child_dim = get_arg(node, "dim", int) % ndim
+        parent_dim = get_arg(parent, "dim", int) % ndim
+        if child_dim != parent_dim:
+            return False
+
+        child_start = get_arg(node, "start", Optional[int])
+        child_end = get_arg(node, "end", Optional[int])
+        child_step = get_arg(node, "step", int)
+        parent_start = get_arg(parent, "start", Optional[int])
+        parent_end = get_arg(parent, "end", Optional[int])
+        parent_step = get_arg(parent, "step", int)
+
+        if child_step != 1 or parent_step != 1:
+            return False
+        if (
+            child_start is None
+            or child_end is None
+            or parent_start is None
+            or parent_end is None
+        ):
+            return False
+
+        grandparent_dim_size = grandparent.meta["val"].shape[parent_dim]
+        parent_dim_size = parent.meta["val"].shape[parent_dim]
+
+        p_start = self._canonicalize(parent_start, grandparent_dim_size)
+        p_end = self._canonicalize(parent_end, grandparent_dim_size)
+        c_start = self._canonicalize(child_start, parent_dim_size)
+        c_end = self._canonicalize(child_end, parent_dim_size)
+
+        new_start = p_start + c_start
+        new_end = min(p_start + c_end, p_end)
+
+        if new_end > grandparent_dim_size:
+            return False
+
+        node.replace_input_with(parent, grandparent)
+        set_arg(node, "start", new_start)
+        set_arg(node, "end", new_end)
+        return True
+
+
 class HierarchicalCSEPass(HierarchicalInplacePassInterface):
     """
     A hierarchical Common Subexpression Elimination (CSE) pass that recursively
@@ -1138,4 +1208,5 @@ class CadenceFuseOpsInGraph:
         FuseFullThenReshapePass,
         FuseTransposeOrPermuteOpPairsPass,
         FuseMeanKeepDimWithViewPass,
+        FuseSliceSameDimPass,
     ]
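
A hedged usage sketch of the new pass, mirroring the tests added in this commit: the GraphBuilder setup and the slice_copy argument order (input, dim, start, end, step) follow the test file below, while how the pass is invoked inside a full export pipeline may differ.

import torch
from executorch.backends.cadence.aot.fuse_ops import FuseSliceSameDimPass
from executorch.backends.test.graph_builder import GraphBuilder
from executorch.exir.dialects._ops import ops as exir_ops

# Build a graph with two chained slice_copy ops on dim 1: [10:50] then [5:20].
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 64))
parent = builder.call_operator(
    op=exir_ops.edge.aten.slice_copy.Tensor, args=(x, 1, 10, 50, 1)
)
child = builder.call_operator(
    op=exir_ops.edge.aten.slice_copy.Tensor, args=(parent, 1, 5, 20, 1)
)
builder.output([child])
gm = builder.get_graph_module()

# The pass returns a PassResult; afterwards a single slice_copy with
# start=15 and end=30 should read directly from x.
result = FuseSliceSameDimPass()(gm)
assert result is not None and result.modified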

backends/cadence/aot/simplify_ops.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 from typing import cast, Optional
 
 import torch
+
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 257 additions & 1 deletion
@@ -26,10 +26,15 @@
     FuseMulTensorIntoQuantPass,
     FuseQuantDequantToRequantizePass,
     FuseQuantizedBatchNormWithConv,
+    FuseSliceSameDimPass,
     FuseTransposeOrPermuteOpPairsPass,
     HierarchicalCSEPass,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
+from executorch.backends.cadence.aot.pass_utils import (
+    count_node,
+    get_arg,
+    op_counts_match,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.backends.test.graph_builder import GraphBuilder
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -1862,3 +1867,254 @@ def test_reduce_single_dim(self) -> None:
             (torch.randn(3, 4, 5),),
             "FuseMeanKeepDimWithViewPass",
         )
+
+
+class TestFuseSliceSameDimPass(TestFusionPassesBase):
+    def _get_single_slice(self, gm: torch.fx.GraphModule) -> torch.fx.Node:
+        slices = gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        self.assertEqual(len(slices), 1)
+        return slices[0]
+
+    def test_basic_chain_bypass(self) -> None:
+        """slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 80))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 3, 0, 78, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 3, 0, 60, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+        merged = self._get_single_slice(result.graph_module)
+        self.assertEqual(get_arg(merged, "start"), 0)
+        self.assertEqual(get_arg(merged, "end"), 60)
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(2, 3, 4, 80),),
+            "FuseSliceSameDimPass",
+        )
+
+    def test_chain_with_offset(self) -> None:
+        """slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30)."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 64))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 10, 50, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 1, 5, 20, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+        merged = self._get_single_slice(result.graph_module)
+        self.assertEqual(get_arg(merged, "start"), 15)
+        self.assertEqual(get_arg(merged, "end"), 30)
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(4, 64),),
+            "FuseSliceSameDimPass",
+        )
+
+    def test_parent_kept_with_other_users(self) -> None:
+        """Parent slice has another user besides the child → parent stays."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 80))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 3, 0, 78, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 3, 0, 60, 1),
+        )
+        neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(parent,))
+        builder.output([child, neg])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2
+        )
+        slices = result.graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        )
+        ends = sorted(get_arg(s, "end") for s in slices)
+        self.assertEqual(ends, [60, 78])
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(2, 3, 4, 80),),
+            "FuseSliceSameDimPass",
+        )
+
+    def test_different_dims_no_change(self) -> None:
+        """Chained slices on different dims → no change."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(8, 16, 32))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 0, 10, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 2, 0, 5, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertFalse(result.modified)
+
+    def test_step_not_one_no_change(self) -> None:
+        """Parent has step != 1 → no change."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 64))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 0, 60, 2),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 1, 0, 10, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertFalse(result.modified)
+
+    def test_no_chain_no_change(self) -> None:
+        """Single slice with no slice user → no change."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(4, 64))
+        sliced = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 0, 32, 1),
+        )
+        builder.output([sliced])
+        original = builder.get_graph_module()
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertFalse(result.modified)
+
+    def test_child_end_clamped_to_parent_range(self) -> None:
+        """Child end exceeds parent output size → clamped to parent_end."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 100))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 10, 50, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 1, 5, 45, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+        merged = self._get_single_slice(result.graph_module)
+        self.assertEqual(get_arg(merged, "start"), 15)
+        self.assertEqual(get_arg(merged, "end"), 50)
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(1, 100),),
+            "FuseSliceSameDimPass",
+        )
+
+    def test_negative_indices(self) -> None:
+        """Negative start/end are canonicalized before merging."""
+        # Parent: slice(dim=1, 10:-10) on size 100 → [10:90], output size 80.
+        # Child: slice(dim=1, 5:-5) on size 80 → [5:75], output size 70.
+        # Merged: [15:85].
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(1, 100))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, 1, 10, -10, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 1, 5, -5, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+        merged = self._get_single_slice(result.graph_module)
+        self.assertEqual(get_arg(merged, "start"), 15)
+        self.assertEqual(get_arg(merged, "end"), 85)
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(1, 100),),
+            "FuseSliceSameDimPass",
+        )
+
+    def test_negative_dim(self) -> None:
+        """Negative dim is canonicalized so matching works across conventions."""
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
+        parent = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(x, -1, 0, 4, 1),
+        )
+        child = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(parent, 3, 0, 2, 1),
+        )
+        builder.output([child])
+        original = builder.get_graph_module()
+        gm_before = copy.deepcopy(original)
+
+        result = cast(PassResult, FuseSliceSameDimPass()(original))
+        self.assertTrue(result.modified)
+        self.assertEqual(
+            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
+        )
+        merged = self._get_single_slice(result.graph_module)
+        self.assertEqual(get_arg(merged, "start"), 0)
+        self.assertEqual(get_arg(merged, "end"), 2)
+        validate_numerics(
+            gm_before,
+            result.graph_module,
+            (torch.randn(2, 3, 4, 5),),
+            "FuseSliceSameDimPass",
+        )

0 commit comments