|
25 | 25 | FuseMulTensorIntoQuantPass, |
26 | 26 | FuseQuantDequantToRequantizePass, |
27 | 27 | FuseQuantizedBatchNormWithConv, |
| 28 | + FuseSliceSameDimPass, |
28 | 29 | FuseTransposeOrPermuteOpPairsPass, |
29 | 30 | HierarchicalCSEPass, |
30 | 31 | ) |
31 | | -from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match |
| 32 | +from executorch.backends.cadence.aot.pass_utils import ( |
| 33 | + count_node, |
| 34 | + get_arg, |
| 35 | + op_counts_match, |
| 36 | +) |
32 | 37 | from executorch.backends.cadence.aot.typing_stubs import expand |
33 | 38 | from executorch.backends.test.graph_builder import GraphBuilder |
34 | 39 | from executorch.exir.dialects._ops import ops as exir_ops |
@@ -1696,3 +1701,251 @@ def __init__(self) -> None: |
1696 | 1701 | # Verify fusion occurred: bn should be removed, conv remains |
1697 | 1702 | self.assertEqual(count_node(gm, conv_op), 1) |
1698 | 1703 | self.assertEqual(count_node(gm, bn_op), 0) |
| 1704 | + |
| 1705 | + |
| 1706 | +class TestFuseSliceSameDimPass(TestFusionPassesBase): |
| 1707 | + def _get_single_slice(self, gm: torch.fx.GraphModule) -> torch.fx.Node: |
| 1708 | + slices = gm.graph.find_nodes( |
| 1709 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 1710 | + ) |
| 1711 | + self.assertEqual(len(slices), 1) |
| 1712 | + return slices[0] |
| 1713 | + |
| 1714 | + def test_basic_chain_bypass(self) -> None: |
| 1715 | + """slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60).""" |
| 1716 | + builder = GraphBuilder() |
| 1717 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) |
| 1718 | + parent = builder.call_operator( |
| 1719 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1720 | + args=(x, 3, 0, 78, 1), |
| 1721 | + ) |
| 1722 | + child = builder.call_operator( |
| 1723 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1724 | + args=(parent, 3, 0, 60, 1), |
| 1725 | + ) |
| 1726 | + builder.output([child]) |
| 1727 | + original = builder.get_graph_module() |
| 1728 | + gm_before = copy.deepcopy(original) |
| 1729 | + |
| 1730 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1731 | + self.assertTrue(result.modified) |
| 1732 | + self.assertEqual( |
| 1733 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1734 | + ) |
| 1735 | + merged = self._get_single_slice(result.graph_module) |
| 1736 | + self.assertEqual(get_arg(merged, "start"), 0) |
| 1737 | + self.assertEqual(get_arg(merged, "end"), 60) |
| 1738 | + validate_numerics( |
| 1739 | + gm_before, |
| 1740 | + result.graph_module, |
| 1741 | + (torch.randn(2, 3, 4, 80),), |
| 1742 | + "FuseSliceSameDimPass", |
| 1743 | + ) |
| 1744 | + |
| 1745 | + def test_chain_with_offset(self) -> None: |
| 1746 | + """slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30).""" |
| 1747 | + builder = GraphBuilder() |
| 1748 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 1749 | + parent = builder.call_operator( |
| 1750 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1751 | + args=(x, 1, 10, 50, 1), |
| 1752 | + ) |
| 1753 | + child = builder.call_operator( |
| 1754 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1755 | + args=(parent, 1, 5, 20, 1), |
| 1756 | + ) |
| 1757 | + builder.output([child]) |
| 1758 | + original = builder.get_graph_module() |
| 1759 | + gm_before = copy.deepcopy(original) |
| 1760 | + |
| 1761 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1762 | + self.assertTrue(result.modified) |
| 1763 | + self.assertEqual( |
| 1764 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1765 | + ) |
| 1766 | + merged = self._get_single_slice(result.graph_module) |
| 1767 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 1768 | + self.assertEqual(get_arg(merged, "end"), 30) |
| 1769 | + validate_numerics( |
| 1770 | + gm_before, |
| 1771 | + result.graph_module, |
| 1772 | + (torch.randn(4, 64),), |
| 1773 | + "FuseSliceSameDimPass", |
| 1774 | + ) |
| 1775 | + |
| 1776 | + def test_parent_kept_with_other_users(self) -> None: |
| 1777 | + """Parent slice has another user besides the child → parent stays.""" |
| 1778 | + builder = GraphBuilder() |
| 1779 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) |
| 1780 | + parent = builder.call_operator( |
| 1781 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1782 | + args=(x, 3, 0, 78, 1), |
| 1783 | + ) |
| 1784 | + child = builder.call_operator( |
| 1785 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1786 | + args=(parent, 3, 0, 60, 1), |
| 1787 | + ) |
| 1788 | + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(parent,)) |
| 1789 | + builder.output([child, neg]) |
| 1790 | + original = builder.get_graph_module() |
| 1791 | + gm_before = copy.deepcopy(original) |
| 1792 | + |
| 1793 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1794 | + self.assertTrue(result.modified) |
| 1795 | + self.assertEqual( |
| 1796 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2 |
| 1797 | + ) |
| 1798 | + slices = result.graph_module.graph.find_nodes( |
| 1799 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 1800 | + ) |
| 1801 | + ends = sorted(get_arg(s, "end") for s in slices) |
| 1802 | + self.assertEqual(ends, [60, 78]) |
| 1803 | + validate_numerics( |
| 1804 | + gm_before, |
| 1805 | + result.graph_module, |
| 1806 | + (torch.randn(2, 3, 4, 80),), |
| 1807 | + "FuseSliceSameDimPass", |
| 1808 | + ) |
| 1809 | + |
| 1810 | + def test_different_dims_no_change(self) -> None: |
| 1811 | + """Chained slices on different dims → no change.""" |
| 1812 | + builder = GraphBuilder() |
| 1813 | + x = builder.placeholder("x", torch.randn(8, 16, 32)) |
| 1814 | + parent = builder.call_operator( |
| 1815 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1816 | + args=(x, 1, 0, 10, 1), |
| 1817 | + ) |
| 1818 | + child = builder.call_operator( |
| 1819 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1820 | + args=(parent, 2, 0, 5, 1), |
| 1821 | + ) |
| 1822 | + builder.output([child]) |
| 1823 | + original = builder.get_graph_module() |
| 1824 | + |
| 1825 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1826 | + self.assertFalse(result.modified) |
| 1827 | + |
| 1828 | + def test_step_not_one_no_change(self) -> None: |
| 1829 | + """Parent has step != 1 → no change.""" |
| 1830 | + builder = GraphBuilder() |
| 1831 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 1832 | + parent = builder.call_operator( |
| 1833 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1834 | + args=(x, 1, 0, 60, 2), |
| 1835 | + ) |
| 1836 | + child = builder.call_operator( |
| 1837 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1838 | + args=(parent, 1, 0, 10, 1), |
| 1839 | + ) |
| 1840 | + builder.output([child]) |
| 1841 | + original = builder.get_graph_module() |
| 1842 | + |
| 1843 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1844 | + self.assertFalse(result.modified) |
| 1845 | + |
| 1846 | + def test_no_chain_no_change(self) -> None: |
| 1847 | + """Single slice with no slice user → no change.""" |
| 1848 | + builder = GraphBuilder() |
| 1849 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 1850 | + sliced = builder.call_operator( |
| 1851 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1852 | + args=(x, 1, 0, 32, 1), |
| 1853 | + ) |
| 1854 | + builder.output([sliced]) |
| 1855 | + original = builder.get_graph_module() |
| 1856 | + |
| 1857 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1858 | + self.assertFalse(result.modified) |
| 1859 | + |
| 1860 | + def test_child_end_clamped_to_parent_range(self) -> None: |
| 1861 | + """Child end exceeds parent output size → clamped to parent_end.""" |
| 1862 | + builder = GraphBuilder() |
| 1863 | + x = builder.placeholder("x", torch.randn(1, 100)) |
| 1864 | + parent = builder.call_operator( |
| 1865 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1866 | + args=(x, 1, 10, 50, 1), |
| 1867 | + ) |
| 1868 | + child = builder.call_operator( |
| 1869 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1870 | + args=(parent, 1, 5, 45, 1), |
| 1871 | + ) |
| 1872 | + builder.output([child]) |
| 1873 | + original = builder.get_graph_module() |
| 1874 | + gm_before = copy.deepcopy(original) |
| 1875 | + |
| 1876 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1877 | + self.assertTrue(result.modified) |
| 1878 | + self.assertEqual( |
| 1879 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1880 | + ) |
| 1881 | + merged = self._get_single_slice(result.graph_module) |
| 1882 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 1883 | + self.assertEqual(get_arg(merged, "end"), 50) |
| 1884 | + validate_numerics( |
| 1885 | + gm_before, |
| 1886 | + result.graph_module, |
| 1887 | + (torch.randn(1, 100),), |
| 1888 | + "FuseSliceSameDimPass", |
| 1889 | + ) |
| 1890 | + |
| 1891 | + def test_negative_indices(self) -> None: |
| 1892 | + """Negative start/end are canonicalized before merging.""" |
| 1893 | + builder = GraphBuilder() |
| 1894 | + x = builder.placeholder("x", torch.randn(1, 100)) |
| 1895 | + parent = builder.call_operator( |
| 1896 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1897 | + args=(x, 1, 10, -10, 1), |
| 1898 | + ) |
| 1899 | + child = builder.call_operator( |
| 1900 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1901 | + args=(parent, 1, 5, -5, 1), |
| 1902 | + ) |
| 1903 | + builder.output([child]) |
| 1904 | + original = builder.get_graph_module() |
| 1905 | + gm_before = copy.deepcopy(original) |
| 1906 | + |
| 1907 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1908 | + self.assertTrue(result.modified) |
| 1909 | + self.assertEqual( |
| 1910 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1911 | + ) |
| 1912 | + merged = self._get_single_slice(result.graph_module) |
| 1913 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 1914 | + self.assertEqual(get_arg(merged, "end"), 85) |
| 1915 | + validate_numerics( |
| 1916 | + gm_before, |
| 1917 | + result.graph_module, |
| 1918 | + (torch.randn(1, 100),), |
| 1919 | + "FuseSliceSameDimPass", |
| 1920 | + ) |
| 1921 | + |
| 1922 | + def test_negative_dim(self) -> None: |
| 1923 | + """Negative dim is canonicalized so matching works across conventions.""" |
| 1924 | + builder = GraphBuilder() |
| 1925 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) |
| 1926 | + parent = builder.call_operator( |
| 1927 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1928 | + args=(x, -1, 0, 4, 1), |
| 1929 | + ) |
| 1930 | + child = builder.call_operator( |
| 1931 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1932 | + args=(parent, 3, 0, 2, 1), |
| 1933 | + ) |
| 1934 | + builder.output([child]) |
| 1935 | + original = builder.get_graph_module() |
| 1936 | + gm_before = copy.deepcopy(original) |
| 1937 | + |
| 1938 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1939 | + self.assertTrue(result.modified) |
| 1940 | + self.assertEqual( |
| 1941 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1942 | + ) |
| 1943 | + merged = self._get_single_slice(result.graph_module) |
| 1944 | + self.assertEqual(get_arg(merged, "start"), 0) |
| 1945 | + self.assertEqual(get_arg(merged, "end"), 2) |
| 1946 | + validate_numerics( |
| 1947 | + gm_before, |
| 1948 | + result.graph_module, |
| 1949 | + (torch.randn(2, 3, 4, 5),), |
| 1950 | + "FuseSliceSameDimPass", |
| 1951 | + ) |
0 commit comments