Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions backends/qualcomm/builders/op_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
from executorch.backends.qualcomm.utils.constants import (
QCOM_DATA,
QCOM_QUANT_ATTRS,
QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
Expand All @@ -31,6 +36,7 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
# args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
Expand All @@ -54,8 +60,25 @@ def define_node(
axis = [len(input_tensor.shape) - 1]
axis_shape = [len(axis)]

weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
has_weight = len(node.args) > 2 and node.args[2] is not None
if has_weight:
weight_node = self.get_node(node.args[2])
weight_tensor = get_parameter(weight_node, self.edge_program)
else:
# elementwise_affine=False: use all-ones weight as identity
weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
weight_node = torch.fx.Node(
node.graph,
node.name + "_runtime_weight",
"call_function",
exir_ops.edge.aten.tensor.default,
(),
{},
)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
Expand All @@ -64,21 +87,34 @@ def define_node(
nodes_to_wrappers,
)

layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]

bias_node = self.get_node(node.args[3])
if bias_node is not None:
# Fake node: even when original bias is absent, QNN still needs it
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the bias is optional for QNN and can be kept as in the original design.
https://docs.qualcomm.com/doc/80-63442-10/topic/MasterOpDef.html#layernorm

has_bias = len(node.args) > 3 and node.args[3] is not None
if has_bias:
bias_node = self.get_node(node.args[3])
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
else:
bias_tensor = torch.zeros(normalized_shapes, dtype=torch.float32)
bias_node = torch.fx.Node(
node.graph,
node.name + "_runtime_bias",
"call_function",
exir_ops.edge.aten.tensor.default,
(),
{},
)
layer_norm_input_tensors.append(bias_tensor_wrapper)
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
quant_attrs[QCOM_ZERO_POINT] = 0
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
)

epsilon = node.args[4]
epsilon = node.args[4] if len(node.args) > 4 else 1e-05

output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
Expand All @@ -94,7 +130,9 @@ def define_node(
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpLayerNorm.op_name,
)
layer_norm_op.AddInputTensors(layer_norm_input_tensors)
layer_norm_op.AddInputTensors(
[input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
)
layer_norm_op.AddOutputTensors([output_tensor_wrapper])
layer_norm_op.AddScalarParam(
OpLayerNorm.param_epsilon,
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def is_parameter(

def get_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function shouldn't return None. Perhaps we should ensure that the node is not None before this function is called.

if node is None:
return None
param = None
if is_param(edge_program, node):
param = get_param(edge_program, node)
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def build_executorch_binary(
with open(pte_name, "wb") as file:
exec_prog_mgr.write_to_file(file)

print(f"Successfully generated {pte_name}.")
if qnn_config.compile_only:
sys.exit(0)

Expand Down
3 changes: 2 additions & 1 deletion backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,8 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:


@register_annotator(
[torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
[torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
QnnConstants.OpLayerNorm.op_name,
)
class LayerNorm(GeneralOpDef):
@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/quantizer/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

def _mark_nodes_as_annotated(nodes: List[Node]):
    """Flag every node in *nodes* as already annotated for quantization.

    Ensures each node carries a ``QuantizationAnnotation`` under
    ``Q_ANNOTATION_KEY`` and sets its ``_annotated`` marker so later
    annotation passes skip it.
    """
    for candidate in nodes:
        # NOTE(review): None entries are skipped defensively; reviewers
        # suggest callers should never pass None here — confirm and
        # consider removing this guard.
        if candidate is None:
            continue
        meta = candidate.meta
        if Q_ANNOTATION_KEY not in meta:
            meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
        meta[Q_ANNOTATION_KEY]._annotated = True
Expand Down
20 changes: 20 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,26 @@ def forward(self, x):
return self.linear(self.layer_norm(x))


class NativeLayerNorm(torch.nn.Module):
    """Test module wrapping ``torch.native_layer_norm`` directly.

    With ``affine=True`` it supplies learnable weight/bias (initialized to
    the identity affine: ones and zeros). With ``affine=False`` it passes
    ``None`` for both, exercising the backend's optional-weight/bias path
    (equivalent to ``nn.LayerNorm(elementwise_affine=False)``). The two
    variants are numerically identical on fresh parameters, but produce
    different graphs — previously both branches passed the parameters, so
    the non-affine lowering path was never actually tested.
    """

    def __init__(self, affine=True):
        super().__init__()
        self.affine = affine
        if affine:
            self.weight = torch.nn.Parameter(torch.ones(768))
            self.bias = torch.nn.Parameter(torch.zeros(768))
        else:
            # No affine transform: native_layer_norm accepts None for both.
            self.weight = None
            self.bias = None
        self.normalized_shape = [768]
        self.eps = 1e-6

    def forward(self, x):
        # native_layer_norm returns (output, mean, rstd); keep output only.
        return torch.native_layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )[0]


class LayerNormAdd(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
15 changes: 15 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,13 @@ def test_qnn_backend_layer_norm(self):
with self.subTest(i=i):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_native_layer_norm(self):
    # Cover both the affine and the non-affine native_layer_norm variants.
    sample_input = (torch.randn(196, 768),)
    variants = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for i, module in enumerate(variants):
        with self.subTest(i=i):
            self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_leaky_relu(self):
torch.manual_seed(8)
test_comb = [
Expand Down Expand Up @@ -3811,6 +3818,14 @@ def test_qnn_backend_layer_norm(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_native_layer_norm(self):
    # Quantized path: QDQ-convert each variant before lowering.
    sample_input = (torch.randn(196, 768),)
    variants = [NativeLayerNorm(), NativeLayerNorm(affine=False)]  # noqa: F405
    for i, module in enumerate(variants):
        with self.subTest(i=i):
            qdq_module = self.get_qdq_module(module, sample_input)
            self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_leaky_relu(self):
test_comb = [
{
Expand Down
Loading