Skip to content

Commit fdb522f

Browse files
Andrew Grebenisan and meta-codesync[bot]
authored and committed
Merge back to back slices on the same dim (pytorch#19120)
Summary: Pull Request resolved: pytorch#19120 If we have back to back slices on the same dimension, we can remove the top slice and just perform the second one. RemoveOrReplacePassInterface will handle a whole cascade if it exists. Reviewed By: abeakkas Differential Revision: D102278253
1 parent e87499f commit fdb522f

3 files changed

Lines changed: 327 additions & 1 deletion

File tree

backends/cadence/aot/fuse_ops.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
HierarchicalInplacePassInterface,
3232
register_cadence_pass,
3333
RemoveOrReplacePassInterface,
34+
set_arg,
3435
)
3536
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
3637
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
@@ -1105,6 +1106,75 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
11051106
return True
11061107

11071108

1109+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FuseSliceSameDimPass(RemoveOrReplacePassInterface):
    """Fuse chained slices on the same dim into a single slice.

    When a slice_copy's input is another slice_copy on the same dimension
    with step=1, the child slice can read directly from the grandparent
    with merged indices, eliminating the intermediate slice.

    Handles negative and out-of-range start/end indices by canonicalizing
    them against the relevant dimension size before merging, mirroring
    aten.slice semantics (negative indices wrap once; the result is then
    clamped to [0, dim_size]).
    """

    @staticmethod
    def _canonicalize(val: int, dim_size: int) -> int:
        """Map a slice index to the range [0, dim_size].

        aten.slice clamps out-of-range indices: a start below -dim_size acts
        as 0 and an end beyond dim_size acts as dim_size. Clamping here keeps
        the merged indices correct for those inputs (a bare wrap-around would
        leave e.g. start=-1000 on size 100 at -900, producing a wrong merge).
        """
        if val < 0:
            val += dim_size
        return min(max(val, 0), dim_size)

    @property
    def targets(self) -> list[EdgeOpOverload]:
        # Only slice_copy nodes are candidates; the parent check below
        # narrows further to slice-of-slice chains.
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        """Rewrite `node` to slice the grandparent directly when possible.

        Returns True (and mutates the graph) iff `node` and its producer are
        both step-1 slice_copy ops on the same canonical dimension.
        """
        parent = get_arg(node, "input", torch.fx.Node)
        if parent.target != exir_ops.edge.aten.slice_copy.Tensor:
            return False

        grandparent = get_arg(parent, "input", torch.fx.Node)
        ndim = len(grandparent.meta["val"].shape)
        # Canonicalize dims so e.g. -1 and ndim-1 compare equal.
        child_dim = get_arg(node, "dim", int) % ndim
        parent_dim = get_arg(parent, "dim", int) % ndim
        if child_dim != parent_dim:
            return False

        child_start = get_arg(node, "start", Optional[int])
        child_end = get_arg(node, "end", Optional[int])
        child_step = get_arg(node, "step", int)
        parent_start = get_arg(parent, "start", Optional[int])
        parent_end = get_arg(parent, "end", Optional[int])
        parent_step = get_arg(parent, "step", int)

        # Merging strided slices needs index arithmetic beyond simple
        # addition; restrict to the common step=1 case.
        if child_step != 1 or parent_step != 1:
            return False
        if (
            child_start is None
            or child_end is None
            or parent_start is None
            or parent_end is None
        ):
            return False

        grandparent_dim_size = grandparent.meta["val"].shape[parent_dim]
        parent_dim_size = parent.meta["val"].shape[parent_dim]

        # Parent indices are relative to the grandparent; child indices are
        # relative to the parent's (already-sliced) output.
        p_start = self._canonicalize(parent_start, grandparent_dim_size)
        p_end = self._canonicalize(parent_end, grandparent_dim_size)
        c_start = self._canonicalize(child_start, parent_dim_size)
        c_end = self._canonicalize(child_end, parent_dim_size)

        # Shift the child window into grandparent coordinates; the end may
        # never extend past what the parent exposed.
        new_start = p_start + c_start
        new_end = min(p_start + c_end, p_end)

        # Defensive guard: with clamped canonicalization this cannot trigger,
        # but bail out rather than emit an out-of-range slice if it ever does.
        if new_end > grandparent_dim_size:
            return False

        node.replace_input_with(parent, grandparent)
        set_arg(node, "start", new_start)
        set_arg(node, "end", new_end)
        return True
1176+
1177+
11081178
class HierarchicalCSEPass(HierarchicalInplacePassInterface):
11091179
"""
11101180
A hierarchical Common Subexpression Elimination (CSE) pass that recursively
@@ -1138,4 +1208,5 @@ class CadenceFuseOpsInGraph:
11381208
FuseFullThenReshapePass,
11391209
FuseTransposeOrPermuteOpPairsPass,
11401210
FuseMeanKeepDimWithViewPass,
1211+
FuseSliceSameDimPass,
11411212
]

backends/cadence/aot/simplify_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import cast, Optional
1414

1515
import torch
16+
1617
from executorch.backends.cadence.aot.pass_utils import (
1718
CadencePassAttribute,
1819
register_cadence_pass,

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 255 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
FuseMulTensorIntoQuantPass,
2727
FuseQuantDequantToRequantizePass,
2828
FuseQuantizedBatchNormWithConv,
29+
FuseSliceSameDimPass,
2930
FuseTransposeOrPermuteOpPairsPass,
3031
HierarchicalCSEPass,
3132
)
32-
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
33+
from executorch.backends.cadence.aot.pass_utils import count_node, get_arg, op_counts_match
3334
from executorch.backends.cadence.aot.typing_stubs import expand
3435
from executorch.backends.test.graph_builder import GraphBuilder
3536
from executorch.exir.dialects._ops import ops as exir_ops
@@ -1862,3 +1863,256 @@ def test_reduce_single_dim(self) -> None:
18621863
(torch.randn(3, 4, 5),),
18631864
"FuseMeanKeepDimWithViewPass",
18641865
)
1866+
1867+
1868+
class TestFuseSliceSameDimPass(TestFusionPassesBase):
    """Unit tests for FuseSliceSameDimPass (merging slice-of-slice chains)."""

    def _get_single_slice(
        self, gm: torch.fx.GraphModule
    ) -> torch.fx.Node:
        """Assert the graph holds exactly one slice_copy and return it."""
        found = gm.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(found), 1)
        return found[0]

    def test_basic_chain_bypass(self) -> None:
        """slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60)."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 80))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 3, 0, 78, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 3, 0, 60, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        fused = self._get_single_slice(outcome.graph_module)
        self.assertEqual(get_arg(fused, "start"), 0)
        self.assertEqual(get_arg(fused, "end"), 60)
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(2, 3, 4, 80),),
            "FuseSliceSameDimPass",
        )

    def test_chain_with_offset(self) -> None:
        """slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30)."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 64))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 10, 50, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 1, 5, 20, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        fused = self._get_single_slice(outcome.graph_module)
        self.assertEqual(get_arg(fused, "start"), 15)
        self.assertEqual(get_arg(fused, "end"), 30)
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(4, 64),),
            "FuseSliceSameDimPass",
        )

    def test_parent_kept_with_other_users(self) -> None:
        """Parent slice has another user besides the child → parent stays."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 80))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 3, 0, 78, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 3, 0, 60, 1),
        )
        # Second consumer of the outer slice keeps it alive after fusion.
        extra_user = gb.call_operator(
            op=exir_ops.edge.aten.neg.default, args=(outer,)
        )
        gb.output([inner, extra_user])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2
        )
        remaining = outcome.graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        ends = sorted(get_arg(s, "end") for s in remaining)
        self.assertEqual(ends, [60, 78])
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(2, 3, 4, 80),),
            "FuseSliceSameDimPass",
        )

    def test_different_dims_no_change(self) -> None:
        """Chained slices on different dims → no change."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(8, 16, 32))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 0, 10, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 2, 0, 5, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertFalse(outcome.modified)

    def test_step_not_one_no_change(self) -> None:
        """Parent has step != 1 → no change."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 64))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 0, 60, 2),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 1, 0, 10, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertFalse(outcome.modified)

    def test_no_chain_no_change(self) -> None:
        """Single slice with no slice producer → no change."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 64))
        lone = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 0, 32, 1),
        )
        gb.output([lone])
        gm = gb.get_graph_module()

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertFalse(outcome.modified)

    def test_child_end_clamped_to_parent_range(self) -> None:
        """Child end exceeds parent output size → clamped to parent_end."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(1, 100))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 10, 50, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 1, 5, 45, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        fused = self._get_single_slice(outcome.graph_module)
        self.assertEqual(get_arg(fused, "start"), 15)
        self.assertEqual(get_arg(fused, "end"), 50)
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(1, 100),),
            "FuseSliceSameDimPass",
        )

    def test_negative_indices(self) -> None:
        """Negative start/end are canonicalized before merging."""
        # Outer: slice(dim=1, 10:-10) on size 100 → [10:90], output size 80.
        # Inner: slice(dim=1, 5:-5) on size 80 → [5:75], output size 70.
        # Merged: [15:85].
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(1, 100))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 10, -10, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 1, 5, -5, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        fused = self._get_single_slice(outcome.graph_module)
        self.assertEqual(get_arg(fused, "start"), 15)
        self.assertEqual(get_arg(fused, "end"), 85)
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(1, 100),),
            "FuseSliceSameDimPass",
        )

    def test_negative_dim(self) -> None:
        """Negative dim is canonicalized so matching works across conventions."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, -1, 0, 4, 1),
        )
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(outer, 3, 0, 2, 1),
        )
        gb.output([inner])
        gm = gb.get_graph_module()
        snapshot = copy.deepcopy(gm)

        outcome = cast(PassResult, FuseSliceSameDimPass()(gm))
        self.assertTrue(outcome.modified)
        self.assertEqual(
            count_node(outcome.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        fused = self._get_single_slice(outcome.graph_module)
        self.assertEqual(get_arg(fused, "start"), 0)
        self.assertEqual(get_arg(fused, "end"), 2)
        validate_numerics(
            snapshot,
            outcome.graph_module,
            (torch.randn(2, 3, 4, 5),),
            "FuseSliceSameDimPass",
        )

0 commit comments

Comments
 (0)