diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py
index 6077d51b099..7ac7e094a97 100644
--- a/backends/qualcomm/_passes/annotate_quant_attrs.py
+++ b/backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -8,7 +8,7 @@
 import torch
 
 from executorch.backends.qualcomm.builders.node_visitor import dq_ops, q_ops
-from executorch.backends.qualcomm.builders.utils import get_parameter
+from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_DTYPE,
     QCOM_ENCODING,
@@ -130,7 +130,7 @@ def _annotate_quant_attrs(
             self._annotate_requant(n)
             # With fold_quant enabled, check if the input of dq op is quantized param.
             param = None
-            if n.target in dq_ops:
+            if n.target in dq_ops and is_parameter(n.args[0], self.edge_program):
                 param = get_parameter(n.args[0], self.edge_program)
             if n.target not in q_ops and param is None:
                 continue
diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py
index a51056eb7bb..abb6bfd657f 100644
--- a/backends/qualcomm/builders/op_layer_norm.py
+++ b/backends/qualcomm/builders/op_layer_norm.py
@@ -11,7 +11,12 @@
 import numpy as np
 import torch
 
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_QUANT_ATTRS,
+    QCOM_ZERO_POINT,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
 
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
@@ -31,6 +36,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
     ) -> PyQnnManager.PyQnnOpWrapper:
+        # args of node : ['input', 'normalized_shape', 'weight', 'bias', 'eps']
        input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
@@ -54,8 +60,26 @@
         axis = [len(input_tensor.shape) - 1]
         axis_shape = [len(axis)]
 
-        weight_node = self.get_node(node.args[2])
-        weight_tensor = get_parameter(weight_node, self.edge_program)
+        has_weight = len(node.args) > 2 and node.args[2] is not None
+        if has_weight:
+            weight_node = self.get_node(node.args[2])
+            assert weight_node is not None
+            weight_tensor = get_parameter(weight_node, self.edge_program)
+        else:
+            # elementwise_affine=False: use all-ones weight as identity
+            weight_tensor = torch.ones(normalized_shapes, dtype=torch.float32)
+            weight_node = torch.fx.Node(
+                node.graph,
+                node.name + "_runtime_weight",
+                "call_function",
+                exir_ops.edge.aten.tensor.default,
+                (),
+                {},
+            )
+            if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+                quant_attrs = quant_attrs.copy()
+                quant_attrs[QCOM_ZERO_POINT] = 0
+                weight_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
         weight_tensor_wrapper = self.define_tensor(
             weight_node,
             node,
@@ -66,8 +90,10 @@
 
         layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper]
 
-        bias_node = self.get_node(node.args[3])
-        if bias_node is not None:
+        has_bias = len(node.args) > 3 and node.args[3] is not None
+        if has_bias:
+            bias_node = self.get_node(node.args[3])
+            assert bias_node is not None
             bias_tensor = get_parameter(bias_node, self.edge_program)
             bias_tensor_wrapper = self.define_tensor(
                 bias_node,
@@ -78,7 +104,7 @@
             )
             layer_norm_input_tensors.append(bias_tensor_wrapper)
 
-        epsilon = node.args[4]
+        epsilon = node.args[4] if len(node.args) > 4 else 1e-05
 
         output_tensor = self.get_tensor(node, node, 0)
         output_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py
index 3345f2e1fc9..7e3b0aaa5f3 100755
--- a/backends/qualcomm/builders/utils.py
+++ b/backends/qualcomm/builders/utils.py
@@ -37,10 +37,12 @@ def get_parameter(
         param = get_buffer(edge_program, node)
     if is_lifted_tensor_constant(edge_program, node):
         param = get_lifted_tensor_constant(edge_program, node)
-    if param is not None:
-        # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
-        assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
-        param = param.type(node.meta["val"].dtype)
+    assert (
+        param is not None
+    ), f"Expect {node.name} to be parameter, buffer, or lifted tensor constant"
+    # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
+    assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
+    param = param.type(node.meta["val"].dtype)
     return param
diff --git a/backends/qualcomm/export_utils.py b/backends/qualcomm/export_utils.py
index 2c7ab2abd02..f66bf9d5858 100644
--- a/backends/qualcomm/export_utils.py
+++ b/backends/qualcomm/export_utils.py
@@ -617,6 +617,7 @@ def build_executorch_binary(
 
     with open(pte_name, "wb") as file:
         exec_prog_mgr.write_to_file(file)
+    print(f"Successfully generated {pte_name}.")
 
     if qnn_config.compile_only:
         sys.exit(0)
diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py
index cd65d02c752..35707f9bb1b 100644
--- a/backends/qualcomm/quantizer/annotators/htp_rules.py
+++ b/backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -828,16 +828,15 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
 
 
 @register_annotator(
-    [torch.ops.aten.layer_norm.default], QnnConstants.OpLayerNorm.op_name
+    [torch.ops.aten.layer_norm.default, torch.ops.aten.native_layer_norm.default],
+    QnnConstants.OpLayerNorm.op_name,
 )
 class LayerNorm(GeneralOpDef):
     @staticmethod
     def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         act_node = node.args[0]
-        weight_node = node.args[2]
-        bias_node = None
-        if len(node.args) > 2:
-            bias_node = node.args[3]
+        weight_node = node.args[2] if len(node.args) > 2 else None
+        bias_node = node.args[3] if len(node.args) > 3 else None
 
         if _is_annotated([node]):
             return
@@ -848,20 +847,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
             act_node,
             input_act_qspec,
         )
-        if input_act_qspec.dtype == torch.int32:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                get_16a16w_qnn_ptq_config().weight,
-            )
-        else:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                input_act_qspec,
-            )
-        nodes_to_mark_annotated = [node, weight_node]
-        if bias_node:
+        nodes_to_mark_annotated = [node]
+        if isinstance(weight_node, Node):
+            if input_act_qspec.dtype == torch.int32:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    get_16a16w_qnn_ptq_config().weight,
+                )
+            else:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    input_act_qspec,
+                )
+            nodes_to_mark_annotated.append(weight_node)
+        if isinstance(bias_node, Node):
             annotate_input_qspec_map(
                 node,
                 bias_node,
diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py
index 60cebfcc5c0..16bc134e561 100644
--- a/backends/qualcomm/quantizer/annotators/lpai_rules.py
+++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -475,10 +475,8 @@ class LayerNorm(GeneralOpDef):
     @staticmethod
     def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
         act_node = node.args[0]
-        weight_node = node.args[2]
-        bias_node = None
-        if len(node.args) > 2:
-            bias_node = node.args[3]
+        weight_node = node.args[2] if len(node.args) > 2 else None
+        bias_node = node.args[3] if len(node.args) > 3 else None
 
         if _is_annotated([node]):
             return
@@ -489,20 +487,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
             act_node,
             input_act_qspec,
         )
-        if input_act_qspec.dtype == torch.int32:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                get_16a16w_qnn_ptq_config().weight,
-            )
-        else:
-            annotate_input_qspec_map(
-                node,
-                weight_node,
-                input_act_qspec,
-            )
-        nodes_to_mark_annotated = [node, weight_node]
-        if bias_node:
+        nodes_to_mark_annotated = [node]
+        if isinstance(weight_node, Node):
+            if input_act_qspec.dtype == torch.int32:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    get_16a16w_qnn_ptq_config().weight,
+                )
+            else:
+                annotate_input_qspec_map(
+                    node,
+                    weight_node,
+                    input_act_qspec,
+                )
+            nodes_to_mark_annotated.append(weight_node)
+        if isinstance(bias_node, Node):
             annotate_input_qspec_map(
                 node,
                 bias_node,
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index b0120dd2848..3e797ffc93f 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1418,9 +1418,14 @@ def forward(self, x):
 
 
 class LayerNorm(torch.nn.Module):
-    def __init__(self, bias=True):
+    def __init__(self, elementwise_affine=True, bias=True):
         super().__init__()
-        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias)
+        self.layer_norm = torch.nn.LayerNorm(
+            [768],
+            eps=1e-6,
+            elementwise_affine=elementwise_affine,
+            bias=bias,
+        )
         self.linear = torch.nn.Linear(768, 196)
 
     def forward(self, x):
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 6c8593eb755..7764130b3b5 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1385,7 +1385,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
@@ -3871,7 +3875,11 @@ def test_qnn_backend_up_sampling_nearest_2d_with_size(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_layer_norm(self):
-        modules = [LayerNorm(), LayerNorm(bias=False)]  # noqa: F405
+        modules = [
+            LayerNorm(),  # noqa: F405
+            LayerNorm(bias=False),  # noqa: F405
+            LayerNorm(elementwise_affine=False),  # noqa: F405
+        ]
         sample_input = (torch.randn(196, 768),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
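
For context on the guards above: with `elementwise_affine=False`, `torch.nn.LayerNorm` owns no learnable `weight`/`bias`, so the exported `aten.layer_norm` call site carries no weight or bias nodes — the trailing args are either dropped or present as `None`. A minimal repro sketch (the `Model` wrapper and the printed arg layout are illustrative, not taken from the patch):

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # No learnable gamma/beta when elementwise_affine=False
        self.norm = torch.nn.LayerNorm([768], eps=1e-6, elementwise_affine=False)

    def forward(self, x):
        return self.norm(x)


ep = torch.export.export(Model(), (torch.randn(196, 768),))
for n in ep.graph.nodes:
    if n.op == "call_function" and "layer_norm" in str(n.target):
        # Expect no weight/bias Nodes here: depending on the torch version
        # the trailing args are absent or None, which is what the
        # len(node.args) / None checks in the patch handle.
        print(n.target, n.args)
```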
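
The all-ones fallback in `op_layer_norm.py` is sound because a ones weight (with zero bias) is the identity affine transform; a quick sanity check of that equivalence:

```python
import torch
import torch.nn.functional as F

x = torch.randn(196, 768)
# layer_norm with no affine parameters ...
ref = F.layer_norm(x, [768])
# ... equals layer_norm with gamma = ones, beta = zeros
out = F.layer_norm(x, [768], weight=torch.ones(768), bias=torch.zeros(768))
assert torch.allclose(ref, out)
```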
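
One way to read the `QCOM_ZERO_POINT = 0` override for the synthetic weight (an inference, not stated in the patch): with a zero zero-point, the constant 1.0 quantizes to `round(1/scale)` and dequantizes back to within one quantization step of 1.0, so the identity weight survives quantization. In numbers:

```python
scale = 1 / 127      # example scale; the patch reuses the node's quant attrs
zero_point = 0       # forced to 0 by the patch for the all-ones weight
q = round(1.0 / scale) + zero_point  # quantize the constant 1.0
deq = (q - zero_point) * scale       # dequantize
assert abs(deq - 1.0) <= scale       # round-trips to ~1.0
```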