|
26 | 26 | FuseMulTensorIntoQuantPass, |
27 | 27 | FuseQuantDequantToRequantizePass, |
28 | 28 | FuseQuantizedBatchNormWithConv, |
| 29 | + FuseSliceSameDimPass, |
29 | 30 | FuseTransposeOrPermuteOpPairsPass, |
30 | 31 | HierarchicalCSEPass, |
31 | 32 | ) |
32 | | -from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match |
| 33 | +from executorch.backends.cadence.aot.pass_utils import count_node, get_arg, op_counts_match |
33 | 34 | from executorch.backends.cadence.aot.typing_stubs import expand |
34 | 35 | from executorch.backends.test.graph_builder import GraphBuilder |
35 | 36 | from executorch.exir.dialects._ops import ops as exir_ops |
@@ -1862,3 +1863,256 @@ def test_reduce_single_dim(self) -> None: |
1862 | 1863 | (torch.randn(3, 4, 5),), |
1863 | 1864 | "FuseMeanKeepDimWithViewPass", |
1864 | 1865 | ) |
| 1866 | + |
| 1867 | + |
| 1868 | +class TestFuseSliceSameDimPass(TestFusionPassesBase): |
| 1869 | + def _get_single_slice( |
| 1870 | + self, gm: torch.fx.GraphModule |
| 1871 | + ) -> torch.fx.Node: |
| 1872 | + slices = gm.graph.find_nodes( |
| 1873 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 1874 | + ) |
| 1875 | + self.assertEqual(len(slices), 1) |
| 1876 | + return slices[0] |
| 1877 | + |
| 1878 | + def test_basic_chain_bypass(self) -> None: |
| 1879 | + """slice(dim=3, 0:78) → slice(dim=3, 0:60) → direct slice(dim=3, 0:60).""" |
| 1880 | + builder = GraphBuilder() |
| 1881 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) |
| 1882 | + parent = builder.call_operator( |
| 1883 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1884 | + args=(x, 3, 0, 78, 1), |
| 1885 | + ) |
| 1886 | + child = builder.call_operator( |
| 1887 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1888 | + args=(parent, 3, 0, 60, 1), |
| 1889 | + ) |
| 1890 | + builder.output([child]) |
| 1891 | + original = builder.get_graph_module() |
| 1892 | + gm_before = copy.deepcopy(original) |
| 1893 | + |
| 1894 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1895 | + self.assertTrue(result.modified) |
| 1896 | + self.assertEqual( |
| 1897 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1898 | + ) |
| 1899 | + merged = self._get_single_slice(result.graph_module) |
| 1900 | + self.assertEqual(get_arg(merged, "start"), 0) |
| 1901 | + self.assertEqual(get_arg(merged, "end"), 60) |
| 1902 | + validate_numerics( |
| 1903 | + gm_before, |
| 1904 | + result.graph_module, |
| 1905 | + (torch.randn(2, 3, 4, 80),), |
| 1906 | + "FuseSliceSameDimPass", |
| 1907 | + ) |
| 1908 | + |
| 1909 | + def test_chain_with_offset(self) -> None: |
| 1910 | + """slice(dim=1, 10:50) → slice(dim=1, 5:20) → direct slice(dim=1, 15:30).""" |
| 1911 | + builder = GraphBuilder() |
| 1912 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 1913 | + parent = builder.call_operator( |
| 1914 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1915 | + args=(x, 1, 10, 50, 1), |
| 1916 | + ) |
| 1917 | + child = builder.call_operator( |
| 1918 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1919 | + args=(parent, 1, 5, 20, 1), |
| 1920 | + ) |
| 1921 | + builder.output([child]) |
| 1922 | + original = builder.get_graph_module() |
| 1923 | + gm_before = copy.deepcopy(original) |
| 1924 | + |
| 1925 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1926 | + self.assertTrue(result.modified) |
| 1927 | + self.assertEqual( |
| 1928 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 1929 | + ) |
| 1930 | + merged = self._get_single_slice(result.graph_module) |
| 1931 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 1932 | + self.assertEqual(get_arg(merged, "end"), 30) |
| 1933 | + validate_numerics( |
| 1934 | + gm_before, |
| 1935 | + result.graph_module, |
| 1936 | + (torch.randn(4, 64),), |
| 1937 | + "FuseSliceSameDimPass", |
| 1938 | + ) |
| 1939 | + |
| 1940 | + def test_parent_kept_with_other_users(self) -> None: |
| 1941 | + """Parent slice has another user besides the child → parent stays.""" |
| 1942 | + builder = GraphBuilder() |
| 1943 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 80)) |
| 1944 | + parent = builder.call_operator( |
| 1945 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1946 | + args=(x, 3, 0, 78, 1), |
| 1947 | + ) |
| 1948 | + child = builder.call_operator( |
| 1949 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1950 | + args=(parent, 3, 0, 60, 1), |
| 1951 | + ) |
| 1952 | + neg = builder.call_operator(op=exir_ops.edge.aten.neg.default, args=(parent,)) |
| 1953 | + builder.output([child, neg]) |
| 1954 | + original = builder.get_graph_module() |
| 1955 | + gm_before = copy.deepcopy(original) |
| 1956 | + |
| 1957 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1958 | + self.assertTrue(result.modified) |
| 1959 | + self.assertEqual( |
| 1960 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 2 |
| 1961 | + ) |
| 1962 | + slices = result.graph_module.graph.find_nodes( |
| 1963 | + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor |
| 1964 | + ) |
| 1965 | + ends = sorted(get_arg(s, "end") for s in slices) |
| 1966 | + self.assertEqual(ends, [60, 78]) |
| 1967 | + validate_numerics( |
| 1968 | + gm_before, |
| 1969 | + result.graph_module, |
| 1970 | + (torch.randn(2, 3, 4, 80),), |
| 1971 | + "FuseSliceSameDimPass", |
| 1972 | + ) |
| 1973 | + |
| 1974 | + def test_different_dims_no_change(self) -> None: |
| 1975 | + """Chained slices on different dims → no change.""" |
| 1976 | + builder = GraphBuilder() |
| 1977 | + x = builder.placeholder("x", torch.randn(8, 16, 32)) |
| 1978 | + parent = builder.call_operator( |
| 1979 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1980 | + args=(x, 1, 0, 10, 1), |
| 1981 | + ) |
| 1982 | + child = builder.call_operator( |
| 1983 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1984 | + args=(parent, 2, 0, 5, 1), |
| 1985 | + ) |
| 1986 | + builder.output([child]) |
| 1987 | + original = builder.get_graph_module() |
| 1988 | + |
| 1989 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 1990 | + self.assertFalse(result.modified) |
| 1991 | + |
| 1992 | + def test_step_not_one_no_change(self) -> None: |
| 1993 | + """Parent has step != 1 → no change.""" |
| 1994 | + builder = GraphBuilder() |
| 1995 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 1996 | + parent = builder.call_operator( |
| 1997 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 1998 | + args=(x, 1, 0, 60, 2), |
| 1999 | + ) |
| 2000 | + child = builder.call_operator( |
| 2001 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2002 | + args=(parent, 1, 0, 10, 1), |
| 2003 | + ) |
| 2004 | + builder.output([child]) |
| 2005 | + original = builder.get_graph_module() |
| 2006 | + |
| 2007 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 2008 | + self.assertFalse(result.modified) |
| 2009 | + |
| 2010 | + def test_no_chain_no_change(self) -> None: |
| 2011 | + """Single slice with no slice user → no change.""" |
| 2012 | + builder = GraphBuilder() |
| 2013 | + x = builder.placeholder("x", torch.randn(4, 64)) |
| 2014 | + sliced = builder.call_operator( |
| 2015 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2016 | + args=(x, 1, 0, 32, 1), |
| 2017 | + ) |
| 2018 | + builder.output([sliced]) |
| 2019 | + original = builder.get_graph_module() |
| 2020 | + |
| 2021 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 2022 | + self.assertFalse(result.modified) |
| 2023 | + |
| 2024 | + def test_child_end_clamped_to_parent_range(self) -> None: |
| 2025 | + """Child end exceeds parent output size → clamped to parent_end.""" |
| 2026 | + builder = GraphBuilder() |
| 2027 | + x = builder.placeholder("x", torch.randn(1, 100)) |
| 2028 | + parent = builder.call_operator( |
| 2029 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2030 | + args=(x, 1, 10, 50, 1), |
| 2031 | + ) |
| 2032 | + child = builder.call_operator( |
| 2033 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2034 | + args=(parent, 1, 5, 45, 1), |
| 2035 | + ) |
| 2036 | + builder.output([child]) |
| 2037 | + original = builder.get_graph_module() |
| 2038 | + gm_before = copy.deepcopy(original) |
| 2039 | + |
| 2040 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 2041 | + self.assertTrue(result.modified) |
| 2042 | + self.assertEqual( |
| 2043 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 2044 | + ) |
| 2045 | + merged = self._get_single_slice(result.graph_module) |
| 2046 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 2047 | + self.assertEqual(get_arg(merged, "end"), 50) |
| 2048 | + validate_numerics( |
| 2049 | + gm_before, |
| 2050 | + result.graph_module, |
| 2051 | + (torch.randn(1, 100),), |
| 2052 | + "FuseSliceSameDimPass", |
| 2053 | + ) |
| 2054 | + |
| 2055 | + def test_negative_indices(self) -> None: |
| 2056 | + """Negative start/end are canonicalized before merging.""" |
| 2057 | + # Parent: slice(dim=1, 10:-10) on size 100 → [10:90], output size 80. |
| 2058 | + # Child: slice(dim=1, 5:-5) on size 80 → [5:75], output size 70. |
| 2059 | + # Merged: [15:85]. |
| 2060 | + builder = GraphBuilder() |
| 2061 | + x = builder.placeholder("x", torch.randn(1, 100)) |
| 2062 | + parent = builder.call_operator( |
| 2063 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2064 | + args=(x, 1, 10, -10, 1), |
| 2065 | + ) |
| 2066 | + child = builder.call_operator( |
| 2067 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2068 | + args=(parent, 1, 5, -5, 1), |
| 2069 | + ) |
| 2070 | + builder.output([child]) |
| 2071 | + original = builder.get_graph_module() |
| 2072 | + gm_before = copy.deepcopy(original) |
| 2073 | + |
| 2074 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 2075 | + self.assertTrue(result.modified) |
| 2076 | + self.assertEqual( |
| 2077 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 2078 | + ) |
| 2079 | + merged = self._get_single_slice(result.graph_module) |
| 2080 | + self.assertEqual(get_arg(merged, "start"), 15) |
| 2081 | + self.assertEqual(get_arg(merged, "end"), 85) |
| 2082 | + validate_numerics( |
| 2083 | + gm_before, |
| 2084 | + result.graph_module, |
| 2085 | + (torch.randn(1, 100),), |
| 2086 | + "FuseSliceSameDimPass", |
| 2087 | + ) |
| 2088 | + |
| 2089 | + def test_negative_dim(self) -> None: |
| 2090 | + """Negative dim is canonicalized so matching works across conventions.""" |
| 2091 | + builder = GraphBuilder() |
| 2092 | + x = builder.placeholder("x", torch.randn(2, 3, 4, 5)) |
| 2093 | + parent = builder.call_operator( |
| 2094 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2095 | + args=(x, -1, 0, 4, 1), |
| 2096 | + ) |
| 2097 | + child = builder.call_operator( |
| 2098 | + op=exir_ops.edge.aten.slice_copy.Tensor, |
| 2099 | + args=(parent, 3, 0, 2, 1), |
| 2100 | + ) |
| 2101 | + builder.output([child]) |
| 2102 | + original = builder.get_graph_module() |
| 2103 | + gm_before = copy.deepcopy(original) |
| 2104 | + |
| 2105 | + result = cast(PassResult, FuseSliceSameDimPass()(original)) |
| 2106 | + self.assertTrue(result.modified) |
| 2107 | + self.assertEqual( |
| 2108 | + count_node(result.graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1 |
| 2109 | + ) |
| 2110 | + merged = self._get_single_slice(result.graph_module) |
| 2111 | + self.assertEqual(get_arg(merged, "start"), 0) |
| 2112 | + self.assertEqual(get_arg(merged, "end"), 2) |
| 2113 | + validate_numerics( |
| 2114 | + gm_before, |
| 2115 | + result.graph_module, |
| 2116 | + (torch.randn(2, 3, 4, 5),), |
| 2117 | + "FuseSliceSameDimPass", |
| 2118 | + ) |
0 commit comments