Skip to content

Commit 18a4c22

Browse files
Andrew Grebenisanmeta-codesync[bot]
authored andcommitted
Merge back to back slices on the same dim
Summary: If we have back to back slices on the same dimension, we can remove the top slice and just perform the second one. Differential Revision: D102278253
1 parent 2095190 commit 18a4c22

2 files changed

Lines changed: 195 additions & 1 deletion

File tree

backends/cadence/aot/simplify_ops.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@
1313
from typing import cast, Optional
1414

1515
import torch
16+
1617
from executorch.backends.cadence.aot.pass_utils import (
1718
CadencePassAttribute,
19+
get_arg,
1820
register_cadence_pass,
1921
RemoveOrReplacePassInterface,
22+
set_arg,
2023
)
2124
from executorch.backends.cadence.aot.utils import rebind
2225
from executorch.exir.dialects._ops import ops as exir_ops
@@ -174,6 +177,61 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
174177
return True
175178

176179

180+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
181+
class BypassRedundantSliceChainPass(RemoveOrReplacePassInterface):
182+
"""Bypass chained slices on the same dim by slicing directly from the source.
183+
184+
When a slice_copy is followed by another slice_copy on the same dimension
185+
with step=1, the child can read directly from the parent's input with
186+
merged indices, eliminating the intermediate slice.
187+
"""
188+
189+
@property
190+
def targets(self) -> list[EdgeOpOverload]:
191+
return [exir_ops.edge.aten.slice_copy.Tensor]
192+
193+
def maybe_remove_or_replace(self, node: Node) -> bool:
194+
parent_input = get_arg(node, "input", Node)
195+
parent_dim = get_arg(node, "dim", int)
196+
parent_start = get_arg(node, "start", Optional[int])
197+
parent_end = get_arg(node, "end", Optional[int])
198+
parent_step = get_arg(node, "step", int)
199+
200+
if parent_step != 1 or parent_start is None or parent_end is None:
201+
return False
202+
203+
input_shape = parent_input.meta["val"].shape
204+
205+
modified = False
206+
for child in list(node.users):
207+
if child.target != exir_ops.edge.aten.slice_copy.Tensor:
208+
continue
209+
210+
child_dim = get_arg(child, "dim", int)
211+
if child_dim != parent_dim:
212+
continue
213+
214+
child_start = get_arg(child, "start", Optional[int])
215+
child_end = get_arg(child, "end", Optional[int])
216+
child_step = get_arg(child, "step", int)
217+
218+
if child_step != 1 or child_start is None or child_end is None:
219+
continue
220+
221+
new_start = parent_start + child_start
222+
new_end = parent_start + child_end
223+
224+
if new_end > input_shape[parent_dim]:
225+
continue
226+
227+
child.replace_input_with(node, parent_input)
228+
set_arg(child, "start", new_start)
229+
set_arg(child, "end", new_end)
230+
modified = True
231+
232+
return modified
233+
234+
177235
@register_cadence_pass(CadencePassAttribute(opt_level=0))
178236
class BindOptionalArgsPass(ExportPass):
179237
"""Bind all optional args and kwargs."""
@@ -217,5 +275,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
217275
class CadenceSimplifyOpsInGraph:
218276
passes = [
219277
SimplifySliceOpPass,
278+
BypassRedundantSliceChainPass,
220279
BindOptionalArgsPass,
221280
]

backends/cadence/aot/tests/test_simplify_ops_passes.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from executorch.backends.cadence.aot.pass_utils import count_node
1616
from executorch.backends.cadence.aot.simplify_ops import (
1717
BindOptionalArgsPass,
18+
BypassRedundantSliceChainPass,
1819
SimplifySliceOpPass,
1920
)
2021
from executorch.backends.cadence.aot.typing_stubs import expand
21-
from executorch.backends.test.graph_builder import single_op_builder
22+
from executorch.backends.test.graph_builder import GraphBuilder, single_op_builder
2223
from executorch.exir.dialects._ops import ops as exir_ops
2324
from torch.fx.passes.infra.pass_base import PassBase, PassResult
2425
from torch.utils import _pytree as pytree
@@ -158,3 +159,137 @@ def test_simplify_slice_op_args(self) -> None:
158159
modified_slice_copy = list(gm.graph.nodes)[1]
159160
self.assertEqual(modified_slice_copy.args[1:], (1, None, 3, 1))
160161
self.assertEqual(modified_slice_copy.kwargs, {})
162+
163+
164+
class TestBypassRedundantSliceChainPass(unittest.TestCase):
165+
def test_basic_chain_bypass(self) -> None:
166+
"""conv → slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60)."""
167+
builder = GraphBuilder()
168+
x = builder.placeholder("x", torch.randn(2, 3, 4, 80))
169+
parent = builder.call_operator(
170+
op=exir_ops.edge.aten.slice_copy.Tensor,
171+
args=(x, 3, 0, 78, 1),
172+
)
173+
child = builder.call_operator(
174+
op=exir_ops.edge.aten.slice_copy.Tensor,
175+
args=(parent, 3, 0, 60, 1),
176+
)
177+
builder.output([child])
178+
original = builder.get_graph_module()
179+
180+
result = transform_and_check_numerics(
181+
original,
182+
(torch.randn(2, 3, 4, 80),),
183+
BypassRedundantSliceChainPass(),
184+
"BypassRedundantSliceChainPass",
185+
)
186+
self.assertTrue(result.modified)
187+
188+
# Parent slice should be dead-code-eliminated since child no longer uses it.
189+
self.assertEqual(
190+
count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
191+
)
192+
193+
def test_chain_with_offset(self) -> None:
194+
"""slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30)."""
195+
builder = GraphBuilder()
196+
x = builder.placeholder("x", torch.randn(4, 64))
197+
parent = builder.call_operator(
198+
op=exir_ops.edge.aten.slice_copy.Tensor,
199+
args=(x, 1, 10, 50, 1),
200+
)
201+
child = builder.call_operator(
202+
op=exir_ops.edge.aten.slice_copy.Tensor,
203+
args=(parent, 1, 5, 20, 1),
204+
)
205+
builder.output([child])
206+
original = builder.get_graph_module()
207+
208+
result = transform_and_check_numerics(
209+
original,
210+
(torch.randn(4, 64),),
211+
BypassRedundantSliceChainPass(),
212+
"BypassRedundantSliceChainPass",
213+
)
214+
self.assertTrue(result.modified)
215+
self.assertEqual(
216+
count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
217+
)
218+
219+
def test_parent_kept_with_other_users(self) -> None:
220+
"""Parent slice has another user besides the child → parent stays."""
221+
builder = GraphBuilder()
222+
x = builder.placeholder("x", torch.randn(2, 3, 4, 80))
223+
parent = builder.call_operator(
224+
op=exir_ops.edge.aten.slice_copy.Tensor,
225+
args=(x, 3, 0, 78, 1),
226+
)
227+
child = builder.call_operator(
228+
op=exir_ops.edge.aten.slice_copy.Tensor,
229+
args=(parent, 3, 0, 60, 1),
230+
)
231+
neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(parent,))
232+
builder.output([child, neg])
233+
original = builder.get_graph_module()
234+
235+
result = transform_and_check_numerics(
236+
original,
237+
(torch.randn(2, 3, 4, 80),),
238+
BypassRedundantSliceChainPass(),
239+
"BypassRedundantSliceChainPass",
240+
)
241+
self.assertTrue(result.modified)
242+
# Parent kept (has neg user), child bypassed → 2 slices remain.
243+
self.assertEqual(
244+
count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2
245+
)
246+
247+
def test_different_dims_no_change(self) -> None:
248+
"""Chained slices on different dims → no change."""
249+
builder = GraphBuilder()
250+
x = builder.placeholder("x", torch.randn(8, 16, 32))
251+
parent = builder.call_operator(
252+
op=exir_ops.edge.aten.slice_copy.Tensor,
253+
args=(x, 1, 0, 10, 1),
254+
)
255+
child = builder.call_operator(
256+
op=exir_ops.edge.aten.slice_copy.Tensor,
257+
args=(parent, 2, 0, 5, 1),
258+
)
259+
builder.output([child])
260+
original = builder.get_graph_module()
261+
262+
result = cast(PassResult, BypassRedundantSliceChainPass()(original))
263+
self.assertFalse(result.modified)
264+
265+
def test_step_not_one_no_change(self) -> None:
266+
"""Parent has step != 1 → no change."""
267+
builder = GraphBuilder()
268+
x = builder.placeholder("x", torch.randn(4, 64))
269+
parent = builder.call_operator(
270+
op=exir_ops.edge.aten.slice_copy.Tensor,
271+
args=(x, 1, 0, 60, 2),
272+
)
273+
child = builder.call_operator(
274+
op=exir_ops.edge.aten.slice_copy.Tensor,
275+
args=(parent, 1, 0, 10, 1),
276+
)
277+
builder.output([child])
278+
original = builder.get_graph_module()
279+
280+
result = cast(PassResult, BypassRedundantSliceChainPass()(original))
281+
self.assertFalse(result.modified)
282+
283+
def test_no_chain_no_change(self) -> None:
284+
"""Single slice with no slice user → no change."""
285+
builder = GraphBuilder()
286+
x = builder.placeholder("x", torch.randn(4, 64))
287+
sliced = builder.call_operator(
288+
op=exir_ops.edge.aten.slice_copy.Tensor,
289+
args=(x, 1, 0, 32, 1),
290+
)
291+
builder.output([sliced])
292+
original = builder.get_graph_module()
293+
294+
result = cast(PassResult, BypassRedundantSliceChainPass()(original))
295+
self.assertFalse(result.modified)

0 commit comments

Comments
 (0)