Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,127 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
return False


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemovePermuteBeforeMeanPass(RemoveOrReplacePassInterface):
    """Remove or sink permute ops that precede mean reductions through unary chains.

    When a permute feeds into a mean (possibly through unary ops like
    dequantize/quantize), two optimizations apply:

    1. If the non-reduced dims keep their relative order (and, for
       ``keepdim=True``, their positions), the permute is fully removed and
       the mean's reduction dims are remapped to the un-permuted layout.
    2. Otherwise, the permute is moved after the mean so it operates on
       smaller data.

    The rewrite is skipped (the node is left untouched) whenever the mean or
    permute arguments are not in the plain positional form this pass can
    reason about: missing/None/empty ``dim``, out-of-range dims, an empty
    permutation, or a permute with more than one user.
    """

    # Ops the walk from mean back to the permute is allowed to step over.
    # Each one is followed through its first positional argument.
    _UNARY_TARGETS: frozenset[EdgeOpOverload] = frozenset(
        {
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.aten.clone.default,
            exir_ops.edge.aten.relu.default,
            exir_ops.edge.aten.neg.default,
            exir_ops.edge.aten.abs.default,
        }
    )

    @property
    def targets(self) -> list[EdgeOpOverload]:
        """Ops this pass is invoked on: only ``aten.mean.dim``."""
        return [exir_ops.edge.aten.mean.dim]

    def _find_permute_through_unary_chain(self, mean_node: Node) -> Optional[Node]:
        """Walk backward from mean through single-user unary ops to find a permute.

        Returns the ``permute_copy`` node, or None when the chain hits an op
        outside ``_UNARY_TARGETS``, an intermediate op with more than one
        user, or an input that is not a positional ``Node``.
        """
        current = mean_node.args[0] if mean_node.args else None
        while isinstance(current, Node):
            if current.target == exir_ops.edge.aten.permute_copy.default:
                return current
            if current.target not in self._UNARY_TARGETS:
                return None
            # A second consumer would still observe the permuted layout.
            if len(current.users) != 1:
                return None
            current = current.args[0] if current.args else None
        return None

    @staticmethod
    def _get_keepdim(node: Node) -> bool:
        """Read ``keepdim`` from the third positional arg or the kwarg (default False)."""
        if len(node.args) >= 3:
            return bool(node.args[2])
        return bool(node.kwargs.get("keepdim", False))

    @staticmethod
    def _can_fully_remove(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> bool:
        """Check whether the post-mean permute would be a no-op.

        Args:
            perm: Dims list of the original permute (negative entries allowed).
            new_reduction_dims: Reduction dims expressed in the permute's
                *input* layout.
            ndim: Rank of the permuted tensor (``len(perm)``).
            keepdim: Whether the mean keeps reduced dims as size-1 dims.
        """
        # Canonicalise both lists so negative dims compare equal to their
        # non-negative counterparts.
        norm_perm = [p % ndim for p in perm]
        reduced = {d % ndim for d in new_reduction_dims}
        if keepdim:
            # Reduced dims are all size 1 and may shuffle freely; every
            # surviving dim must stay in place.
            return all(norm_perm[d] == d for d in range(ndim) if d not in reduced)
        # Reduced dims disappear, so only the relative order of survivors matters.
        survivors = [d for d in norm_perm if d not in reduced]
        return survivors == sorted(survivors)

    @staticmethod
    def _compute_post_mean_perm(
        perm: list[int], new_reduction_dims: list[int], ndim: int, keepdim: bool
    ) -> list[int]:
        """Compute the permutation to insert after the mean.

        With ``keepdim=True`` the rank is unchanged, so the (canonicalised)
        original permutation is reused. Otherwise the permutation is
        re-expressed over the surviving dims only.
        """
        norm_perm = [p % ndim for p in perm]
        if keepdim:
            return norm_perm
        reduced = {d % ndim for d in new_reduction_dims}
        # Position of each surviving input dim within the mean's output.
        rank_of = {
            d: i
            for i, d in enumerate(d for d in range(ndim) if d not in reduced)
        }
        return [rank_of[d] for d in norm_perm if d not in reduced]

    def maybe_remove_or_replace(self, node: Node) -> bool:
        """Rewrite ``permute -> (unary)* -> mean`` rooted at ``node``.

        Returns True when the graph was modified, False when the pattern does
        not match or cannot be rewritten safely.
        """
        if len(node.args) < 2:
            return False
        raw_dims = node.args[1]
        if isinstance(raw_dims, int):
            raw_dims = [raw_dims]
        # None / empty dims are left alone: the remapping below is only
        # defined for an explicit, non-empty set of reduction dims.
        if not raw_dims:
            return False
        reduction_dims = cast(list[int], list(raw_dims))

        permute_node = self._find_permute_through_unary_chain(node)
        if permute_node is None or len(permute_node.users) != 1:
            return False
        if len(permute_node.args) < 2:
            return False

        permute_input = permute_node.args[0]
        if not isinstance(permute_input, Node):
            return False

        raw_perm = cast(list[int], permute_node.args[1])
        ndim = len(raw_perm)
        if ndim == 0:
            return False
        # Refuse out-of-range dims instead of letting ``%`` wrap them silently.
        if any(not -ndim <= d < ndim for d in reduction_dims):
            return False
        perm = [p % ndim for p in raw_perm]

        # Dim ``d`` of the permuted tensor is dim ``perm[d]`` of its input.
        new_reduction_dims = [perm[d % ndim] for d in reduction_dims]
        keepdim = self._get_keepdim(node)
        can_remove = self._can_fully_remove(perm, new_reduction_dims, ndim, keepdim)

        # Bypass the permute, then point the mean at the un-permuted dims.
        # node.args[0] is re-read after the replacement on purpose: when the
        # permute fed the mean directly it has just been rewritten.
        permute_node.replace_all_uses_with(permute_input)
        node.args = (node.args[0], new_reduction_dims) + node.args[2:]

        if not can_remove:
            post_perm = self._compute_post_mean_perm(
                perm, new_reduction_dims, ndim, keepdim
            )
            graph = node.graph
            # NOTE(review): the new node gets no ``meta`` here, and the unary
            # nodes keep metadata computed for the permuted layout; confirm
            # the pass infrastructure re-propagates it.
            with graph.inserting_after(node):
                new_permute = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    args=(node, post_perm),
                )
            # Snapshot the users: rewiring mutates ``node.users``.
            for user in list(node.users):
                if user is not new_permute:
                    user.replace_input_with(node, new_permute)

        return True


@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemovePermutesAroundElementwiseOps(_SharedRemovePermutesAroundElementwiseOps):
permutable_ops: set[EdgeOpOverload] = (
Expand Down Expand Up @@ -646,6 +767,7 @@ class CommonRemovePasses:
RemoveNopSliceOrViewOpPass,
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemovePermuteBeforeMeanPass,
RemovePermutesAroundElementwiseOps,
FuseTransposeOrPermuteOpPairsPass,
RemoveSqueezeViewBeforeElementwiseOps,
Expand Down
195 changes: 195 additions & 0 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
RemoveNopLinalgVectorNormOpPass,
RemoveNopMulOpPass,
RemoveNopSliceOrViewOpPass,
RemovePermuteBeforeMeanPass,
RemovePermutesAroundElementwiseOps,
RemoveSqueezeViewBeforeElementwiseOps,
RemoveToOpsPass,
Expand Down Expand Up @@ -1013,3 +1014,197 @@ def test_remove_cat_from_slice_copy_second_input(self) -> None:

# Output should remain the same.
self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))

def test_remove_permute_before_mean_fully_removed(self) -> None:
    """permute -> relu -> mean; the single surviving dim keeps its order, so no permute remains."""
    permute_op = exir_ops.edge.aten.permute_copy.default
    mean_op = exir_ops.edge.aten.mean.dim

    # x(2, 3, 4, 5) -> permute [0, 2, 3, 1] -> relu -> mean over dims [1, 2].
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
    perm_out = gb.call_operator(op=permute_op, args=(inp, [0, 2, 3, 1]))
    relu_out = gb.call_operator(
        op=exir_ops.edge.aten.relu.default, args=(perm_out,)
    )
    mean_out = gb.call_operator(op=mean_op, args=(relu_out, [1, 2], False))
    gb.output([mean_out])

    original = gb.get_graph_module()
    reference_gm = copy.deepcopy(original)

    pass_result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
    transformed = pass_result.graph_module

    # No permute is left anywhere in the graph.
    self.assertEqual(count_node(transformed, permute_op), 0)

    # The reduction dims are now expressed in the un-permuted layout.
    means = transformed.graph.find_nodes(op="call_function", target=mean_op)
    self.assertEqual(len(means), 1)
    self.assertEqual(means[0].args[1], [2, 3])

    # Numerical equivalence against the untouched copy.
    validate(
        reference_gm,
        transformed,
        (torch.randn(2, 3, 4, 5),),
        "RemovePermuteBeforeMeanPass",
    )

def test_remove_permute_before_mean_sunk_after(self) -> None:
    """permute -> relu -> mean; surviving dims swap order, so the permute moves below the mean."""
    permute_op = exir_ops.edge.aten.permute_copy.default
    mean_op = exir_ops.edge.aten.mean.dim

    # x(2, 3, 4, 5) -> permute [2, 0, 3, 1] -> relu -> mean over dims [2, 3].
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
    perm_out = gb.call_operator(op=permute_op, args=(inp, [2, 0, 3, 1]))
    relu_out = gb.call_operator(
        op=exir_ops.edge.aten.relu.default, args=(perm_out,)
    )
    mean_out = gb.call_operator(op=mean_op, args=(relu_out, [2, 3], False))
    gb.output([mean_out])

    original = gb.get_graph_module()
    reference_gm = copy.deepcopy(original)

    pass_result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
    transformed = pass_result.graph_module

    # Exactly one permute survives (the one sunk below the mean).
    self.assertEqual(count_node(transformed, permute_op), 1)

    # The mean now reduces the corresponding dims of the un-permuted input.
    means = transformed.graph.find_nodes(op="call_function", target=mean_op)
    self.assertEqual(len(means), 1)
    self.assertEqual(means[0].args[1], [3, 1])

    # The surviving permute consumes the mean and swaps the two remaining dims.
    permutes = transformed.graph.find_nodes(op="call_function", target=permute_op)
    sunk = permutes[0]
    self.assertEqual(sunk.args[0], means[0])
    self.assertEqual(sunk.args[1], [1, 0])

    # Numerical equivalence against the untouched copy.
    validate(
        reference_gm,
        transformed,
        (torch.randn(2, 3, 4, 5),),
        "RemovePermuteBeforeMeanPass",
    )

def test_remove_permute_before_mean_keepdim_true(self) -> None:
    """Permute → relu → mean(keepdim=True) where only reduced dims shuffle → fully remove."""
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(2, 3, 4, 5))
    permuted = builder.call_operator(
        op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 1, 3, 2])
    )
    relu = builder.call_operator(
        op=exir_ops.edge.aten.relu.default, args=(permuted,)
    )
    mean = builder.call_operator(
        op=exir_ops.edge.aten.mean.dim, args=(relu, [2, 3], True)
    )
    builder.output([mean])
    original = builder.get_graph_module()
    gm_before = copy.deepcopy(original)

    graph_after = cast(
        PassResult, RemovePermuteBeforeMeanPass()(original)
    ).graph_module

    # Permute fully removed (only reduced dims were shuffled).
    self.assertEqual(
        count_node(graph_after, exir_ops.edge.aten.permute_copy.default), 0
    )

    # Mean reduction dims should be remapped through the permutation:
    # permuted dims [2, 3] correspond to input dims [perm[2], perm[3]] = [3, 2].
    mean_nodes = graph_after.graph.find_nodes(
        op="call_function", target=exir_ops.edge.aten.mean.dim
    )
    self.assertEqual(len(mean_nodes), 1)
    self.assertEqual(mean_nodes[0].args[1], [3, 2])

    # Numerical equivalence against the untouched copy.
    validate(
        gm_before,
        graph_after,
        (torch.randn(2, 3, 4, 5),),
        "RemovePermuteBeforeMeanPass",
    )

def test_remove_permute_before_mean_keepdim_true_sunk(self) -> None:
    """permute -> relu -> mean(keepdim=True); a surviving dim changes position, so the permute is sunk."""
    permute_op = exir_ops.edge.aten.permute_copy.default

    # x(2, 3, 4, 5) -> permute [0, 2, 3, 1] -> relu -> mean over [1, 2], keepdim.
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
    perm_out = gb.call_operator(op=permute_op, args=(inp, [0, 2, 3, 1]))
    relu_out = gb.call_operator(
        op=exir_ops.edge.aten.relu.default, args=(perm_out,)
    )
    mean_out = gb.call_operator(
        op=exir_ops.edge.aten.mean.dim, args=(relu_out, [1, 2], True)
    )
    gb.output([mean_out])

    original = gb.get_graph_module()
    reference_gm = copy.deepcopy(original)

    pass_result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
    transformed = pass_result.graph_module

    # Exactly one permute survives (the one sunk below the mean).
    self.assertEqual(count_node(transformed, permute_op), 1)

    # Rank is preserved under keepdim=True, so the original dims list is reused.
    permutes = transformed.graph.find_nodes(op="call_function", target=permute_op)
    self.assertEqual(permutes[0].args[1], [0, 2, 3, 1])

    # Numerical equivalence against the untouched copy.
    validate(
        reference_gm,
        transformed,
        (torch.randn(2, 3, 4, 5),),
        "RemovePermuteBeforeMeanPass",
    )

def test_remove_permute_before_mean_no_permute(self) -> None:
    """relu -> mean with no upstream permute: the pass must report no modification."""
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
    relu_out = gb.call_operator(op=exir_ops.edge.aten.relu.default, args=(inp,))
    mean_out = gb.call_operator(
        op=exir_ops.edge.aten.mean.dim,
        args=(relu_out, [2, 3], False),
    )
    gb.output([mean_out])
    original = gb.get_graph_module()

    pass_result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
    self.assertFalse(pass_result.modified)

def test_remove_permute_before_mean_multi_user(self) -> None:
    """A permute consumed by two ops must be left alone by the pass."""
    gb = GraphBuilder()
    inp = gb.placeholder("x", torch.randn(2, 3, 4, 5))
    perm_out = gb.call_operator(
        op=exir_ops.edge.aten.permute_copy.default,
        args=(inp, [0, 2, 3, 1]),
    )
    relu_out = gb.call_operator(
        op=exir_ops.edge.aten.relu.default, args=(perm_out,)
    )
    mean_out = gb.call_operator(
        op=exir_ops.edge.aten.mean.dim,
        args=(relu_out, [1, 2], False),
    )
    # neg is a second consumer of the permute output, which blocks the rewrite.
    neg_out = gb.call_operator(
        op=exir_ops.edge.aten.neg.default, args=(perm_out,)
    )
    gb.output([mean_out, neg_out])
    original = gb.get_graph_module()

    pass_result = cast(PassResult, RemovePermuteBeforeMeanPass()(original))
    self.assertFalse(pass_result.modified)
Loading