Skip to content

Commit c1d482e

Browse files
authored
Merge back-to-back slices on the same dim
Differential Revision: D102425537 Pull Request resolved: #19128
1 parent 7e2ff8a commit c1d482e

2 files changed

Lines changed: 325 additions & 1 deletion

File tree

backends/cadence/aot/fuse_ops.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
HierarchicalInplacePassInterface,
3232
register_cadence_pass,
3333
RemoveOrReplacePassInterface,
34+
set_arg,
3435
)
3536
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
3637
from executorch.backends.transforms.fuse_cascaded_transpose_or_permute_ops import (
@@ -1003,6 +1004,75 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
10031004
return True
10041005

10051006

1007+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FuseSliceSameDimPass(RemoveOrReplacePassInterface):
    """Collapse two consecutive slice_copy ops on the same dimension.

    When a slice_copy consumes the output of another slice_copy that acts on
    the same (canonicalized) dimension and both use step=1, the inner slice
    is bypassed: the outer slice reads straight from the original tensor with
    its start/end recomputed in the original tensor's coordinates.

    Negative start/end values are normalized against the appropriate
    dimension size before the two windows are composed.
    """

    @staticmethod
    def _canonicalize(val: int, dim_size: int) -> int:
        # Map a negative index to its non-negative equivalent.
        return val if val >= 0 else val + dim_size

    @property
    def targets(self) -> list[EdgeOpOverload]:
        return [exir_ops.edge.aten.slice_copy.Tensor]

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # Only fuse when the producer is itself a slice_copy.
        inner = get_arg(node, "input", torch.fx.Node)
        if inner.target != exir_ops.edge.aten.slice_copy.Tensor:
            return False

        source = get_arg(inner, "input", torch.fx.Node)
        rank = len(source.meta["val"].shape)
        # Canonicalize both dims so e.g. -1 and rank-1 compare equal.
        outer_dim = get_arg(node, "dim", int) % rank
        inner_dim = get_arg(inner, "dim", int) % rank
        if outer_dim != inner_dim:
            return False

        outer_start = get_arg(node, "start", Optional[int])
        outer_end = get_arg(node, "end", Optional[int])
        outer_step = get_arg(node, "step", int)
        inner_start = get_arg(inner, "start", Optional[int])
        inner_end = get_arg(inner, "end", Optional[int])
        inner_step = get_arg(inner, "step", int)

        # Merging is only valid for unit strides with fully specified bounds.
        if outer_step != 1 or inner_step != 1:
            return False
        if (
            outer_start is None
            or outer_end is None
            or inner_start is None
            or inner_end is None
        ):
            return False

        source_dim_size = source.meta["val"].shape[inner_dim]
        inner_dim_size = inner.meta["val"].shape[inner_dim]

        # Normalize negative indices against the tensor each slice reads from.
        i_start = self._canonicalize(inner_start, source_dim_size)
        i_end = self._canonicalize(inner_end, source_dim_size)
        o_start = self._canonicalize(outer_start, inner_dim_size)
        o_end = self._canonicalize(outer_end, inner_dim_size)

        # Compose the two windows in the source tensor's coordinates; the
        # merged end can never extend past the inner slice's own end.
        merged_start = i_start + o_start
        merged_end = min(i_start + o_end, i_end)

        # Conservatively bail rather than emit out-of-range indices.
        if merged_end > source_dim_size:
            return False

        node.replace_input_with(inner, source)
        set_arg(node, "start", merged_start)
        set_arg(node, "end", merged_end)
        return True
1074+
1075+
10061076
class HierarchicalCSEPass(HierarchicalInplacePassInterface):
10071077
"""
10081078
A hierarchical Common Subexpression Elimination (CSE) pass that recursively
@@ -1035,4 +1105,5 @@ class CadenceFuseOpsInGraph:
10351105
FuseMulScalarIntoDequantPass,
10361106
FuseFullThenReshapePass,
10371107
FuseTransposeOrPermuteOpPairsPass,
1108+
FuseSliceSameDimPass,
10381109
]

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 254 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@
2525
FuseMulTensorIntoQuantPass,
2626
FuseQuantDequantToRequantizePass,
2727
FuseQuantizedBatchNormWithConv,
28+
FuseSliceSameDimPass,
2829
FuseTransposeOrPermuteOpPairsPass,
2930
HierarchicalCSEPass,
3031
)
31-
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
32+
from executorch.backends.cadence.aot.pass_utils import (
33+
count_node,
34+
get_arg,
35+
op_counts_match,
36+
)
3237
from executorch.backends.cadence.aot.typing_stubs import expand
3338
from executorch.backends.test.graph_builder import GraphBuilder
3439
from executorch.exir.dialects._ops import ops as exir_ops
@@ -1696,3 +1701,251 @@ def __init__(self) -> None:
16961701
# Verify fusion occurred: bn should be removed, conv remains
16971702
self.assertEqual(count_node(gm, conv_op), 1)
16981703
self.assertEqual(count_node(gm, bn_op), 0)
1704+
1705+
1706+
class TestFuseSliceSameDimPass(TestFusionPassesBase):
    """Tests for FuseSliceSameDimPass (merging chained same-dim slices)."""

    def _build_slice_chain(
        self, shape, first_slice_args, second_slice_args
    ) -> torch.fx.GraphModule:
        # Build placeholder -> slice -> slice; slice args are (dim, start, end, step).
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(*shape))
        first = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, *first_slice_args),
        )
        second = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(first, *second_slice_args),
        )
        gb.output([second])
        return gb.get_graph_module()

    def _run(self, gm: torch.fx.GraphModule) -> PassResult:
        # Apply the pass under test and surface its PassResult.
        return cast(PassResult, FuseSliceSameDimPass()(gm))

    def _get_single_slice(self, gm: torch.fx.GraphModule) -> torch.fx.Node:
        found = gm.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(len(found), 1)
        return found[0]

    def _assert_fused_to(self, result: PassResult, start: int, end: int) -> None:
        # The graph must contain exactly one slice with the merged bounds.
        self.assertTrue(result.modified)
        self.assertEqual(
            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )
        merged = self._get_single_slice(result.graph_module)
        self.assertEqual(get_arg(merged, "start"), start)
        self.assertEqual(get_arg(merged, "end"), end)

    def test_basic_chain_bypass(self) -> None:
        """slice(dim=3, 0:78) -> slice(dim=3, 0:60) becomes direct slice(dim=3, 0:60)."""
        original = self._build_slice_chain(
            (2, 3, 4, 80), (3, 0, 78, 1), (3, 0, 60, 1)
        )
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self._assert_fused_to(result, 0, 60)
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(2, 3, 4, 80),),
            "FuseSliceSameDimPass",
        )

    def test_chain_with_offset(self) -> None:
        """slice(dim=1, 10:50) -> slice(dim=1, 5:20) becomes direct slice(dim=1, 15:30)."""
        original = self._build_slice_chain((4, 64), (1, 10, 50, 1), (1, 5, 20, 1))
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self._assert_fused_to(result, 15, 30)
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(4, 64),),
            "FuseSliceSameDimPass",
        )

    def test_parent_kept_with_other_users(self) -> None:
        """The inner slice survives when it feeds more than just the outer slice."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(2, 3, 4, 80))
        inner = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 3, 0, 78, 1),
        )
        outer = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inner, 3, 0, 60, 1),
        )
        negated = gb.call_operator(op=exir_ops.edge.aten.neg.default, args=(inner,))
        gb.output([outer, negated])
        original = gb.get_graph_module()
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self.assertTrue(result.modified)
        # Both slices remain: the fused outer one plus the still-used inner one.
        self.assertEqual(
            count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2
        )
        remaining = result.graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
        )
        self.assertEqual(sorted(get_arg(s, "end") for s in remaining), [60, 78])
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(2, 3, 4, 80),),
            "FuseSliceSameDimPass",
        )

    def test_different_dims_no_change(self) -> None:
        """Chained slices on different dims are left alone."""
        original = self._build_slice_chain((8, 16, 32), (1, 0, 10, 1), (2, 0, 5, 1))
        self.assertFalse(self._run(original).modified)

    def test_step_not_one_no_change(self) -> None:
        """A non-unit step on the inner slice blocks fusion."""
        original = self._build_slice_chain((4, 64), (1, 0, 60, 2), (1, 0, 10, 1))
        self.assertFalse(self._run(original).modified)

    def test_no_chain_no_change(self) -> None:
        """A lone slice with no slice producer is left alone."""
        gb = GraphBuilder()
        inp = gb.placeholder("x", torch.randn(4, 64))
        only_slice = gb.call_operator(
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(inp, 1, 0, 32, 1),
        )
        gb.output([only_slice])

        self.assertFalse(self._run(gb.get_graph_module()).modified)

    def test_child_end_clamped_to_parent_range(self) -> None:
        """An outer end past the inner slice's extent is clamped to the inner end."""
        original = self._build_slice_chain((1, 100), (1, 10, 50, 1), (1, 5, 45, 1))
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self._assert_fused_to(result, 15, 50)
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(1, 100),),
            "FuseSliceSameDimPass",
        )

    def test_negative_indices(self) -> None:
        """Negative start/end values are canonicalized before merging."""
        original = self._build_slice_chain((1, 100), (1, 10, -10, 1), (1, 5, -5, 1))
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self._assert_fused_to(result, 15, 85)
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(1, 100),),
            "FuseSliceSameDimPass",
        )

    def test_negative_dim(self) -> None:
        """A negative dim on one slice still matches its positive counterpart."""
        original = self._build_slice_chain(
            (2, 3, 4, 5), (-1, 0, 4, 1), (3, 0, 2, 1)
        )
        gm_before = copy.deepcopy(original)

        result = self._run(original)
        self._assert_fused_to(result, 0, 2)
        validate_numerics(
            gm_before,
            result.graph_module,
            (torch.randn(2, 3, 4, 5),),
            "FuseSliceSameDimPass",
        )

0 commit comments

Comments
 (0)