Skip to content

Commit 0a2c42c

Browse files
committed
Fix QNN LayerNorm optional arg annotation
1 parent 5beaa57 commit 0a2c42c

2 files changed

Lines changed: 36 additions & 36 deletions

File tree

backends/qualcomm/quantizer/annotators/htp_rules.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -835,10 +835,8 @@ class LayerNorm(GeneralOpDef):
835835
@staticmethod
836836
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
837837
act_node = node.args[0]
838-
weight_node = node.args[2]
839-
bias_node = None
840-
if len(node.args) > 2:
841-
bias_node = node.args[3]
838+
weight_node = node.args[2] if len(node.args) > 2 else None
839+
bias_node = node.args[3] if len(node.args) > 3 else None
842840

843841
if _is_annotated([node]):
844842
return
@@ -849,20 +847,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
849847
act_node,
850848
input_act_qspec,
851849
)
852-
if input_act_qspec.dtype == torch.int32:
853-
annotate_input_qspec_map(
854-
node,
855-
weight_node,
856-
get_16a16w_qnn_ptq_config().weight,
857-
)
858-
else:
859-
annotate_input_qspec_map(
860-
node,
861-
weight_node,
862-
input_act_qspec,
863-
)
864-
nodes_to_mark_annotated = [node, weight_node]
865-
if bias_node:
850+
nodes_to_mark_annotated = [node]
851+
if isinstance(weight_node, Node):
852+
if input_act_qspec.dtype == torch.int32:
853+
annotate_input_qspec_map(
854+
node,
855+
weight_node,
856+
get_16a16w_qnn_ptq_config().weight,
857+
)
858+
else:
859+
annotate_input_qspec_map(
860+
node,
861+
weight_node,
862+
input_act_qspec,
863+
)
864+
nodes_to_mark_annotated.append(weight_node)
865+
if isinstance(bias_node, Node):
866866
annotate_input_qspec_map(
867867
node,
868868
bias_node,

backends/qualcomm/quantizer/annotators/lpai_rules.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,8 @@ class LayerNorm(GeneralOpDef):
475475
@staticmethod
476476
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
477477
act_node = node.args[0]
478-
weight_node = node.args[2]
479-
bias_node = None
480-
if len(node.args) > 2:
481-
bias_node = node.args[3]
478+
weight_node = node.args[2] if len(node.args) > 2 else None
479+
bias_node = node.args[3] if len(node.args) > 3 else None
482480

483481
if _is_annotated([node]):
484482
return
@@ -489,20 +487,22 @@ def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
489487
act_node,
490488
input_act_qspec,
491489
)
492-
if input_act_qspec.dtype == torch.int32:
493-
annotate_input_qspec_map(
494-
node,
495-
weight_node,
496-
get_16a16w_qnn_ptq_config().weight,
497-
)
498-
else:
499-
annotate_input_qspec_map(
500-
node,
501-
weight_node,
502-
input_act_qspec,
503-
)
504-
nodes_to_mark_annotated = [node, weight_node]
505-
if bias_node:
490+
nodes_to_mark_annotated = [node]
491+
if isinstance(weight_node, Node):
492+
if input_act_qspec.dtype == torch.int32:
493+
annotate_input_qspec_map(
494+
node,
495+
weight_node,
496+
get_16a16w_qnn_ptq_config().weight,
497+
)
498+
else:
499+
annotate_input_qspec_map(
500+
node,
501+
weight_node,
502+
input_act_qspec,
503+
)
504+
nodes_to_mark_annotated.append(weight_node)
505+
if isinstance(bias_node, Node):
506506
annotate_input_qspec_map(
507507
node,
508508
bias_node,

Comments (0)