[Qualcomm] Support native_layer_norm and affine-free LayerNorm in QNN backend #18990
base: main
Changes from 3 commits
@@ -29,7 +29,9 @@ def is_parameter(

 def get_parameter(
     node: torch.fx.Node, edge_program: torch.export.ExportedProgram
-) -> torch.Tensor:
+) -> Optional[torch.Tensor]:
Collaborator
This function shouldn't return None. Perhaps we should ensure that the node is not None before this function is called.
+    if node is None:
+        return None
     param = None
     if is_param(edge_program, node):
         param = get_param(edge_program, node)
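A minimal sketch of the reviewer's suggestion, not part of this PR: keep `get_parameter` returning a plain `torch.Tensor` and guard against the missing optional input at the call site instead. The `bias_node` and `bias_tensor` names below are illustrative assumptions.

```python
# Hypothetical call-site guard; names are illustrative, not taken from the PR.
bias_node = node.args[3] if len(node.args) > 3 else None
if bias_node is not None:
    bias_tensor = get_parameter(bias_node, edge_program)  # node is guaranteed non-None here
else:
    bias_tensor = None  # affine-free case: there is no bias parameter to fetch
```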
@@ -29,6 +29,8 @@

 def _mark_nodes_as_annotated(nodes: List[Node]):
     for node in nodes:
+        if node is None:
Contributor
We might want to get rid of this, CC: @shewu-quic

Collaborator
I think the node should not be None in this function.
+            continue
         if Q_ANNOTATION_KEY not in node.meta:
             node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
         node.meta[Q_ANNOTATION_KEY]._annotated = True
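A hedged sketch of the alternative the reviewers hint at, assumed rather than taken from this PR: filter out absent optional inputs at the call site so `_mark_nodes_as_annotated` never receives None. The `input_node`, `weight_node`, and `bias_node` names are illustrative.

```python
# Illustrative only: drop absent optional inputs before annotating, so the
# helper can keep assuming every element of `nodes` is a real fx.Node.
nodes_to_annotate = [n for n in (input_node, weight_node, bias_node) if n is not None]
_mark_nodes_as_annotated(nodes_to_annotate)
```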
@@ -1388,6 +1388,26 @@ def forward(self, x):
         return self.linear(self.layer_norm(x))


+class NativeLayerNorm(torch.nn.Module):
+    def __init__(self, affine=True):
+        super().__init__()
+        self.affine = affine
+        self.weight = torch.nn.Parameter(torch.ones(768))
+        self.bias = torch.nn.Parameter(torch.zeros(768))
+        self.normalized_shape = [768]
+        self.eps = 1e-6
+
+    def forward(self, x):
+        if self.affine:
Collaborator
These two branches seem to be the same. Would it be possible to extend the current …
+            return torch.native_layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )[0]
+        else:
+            return torch.native_layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )[0]
+
+
 class LayerNormAdd(torch.nn.Module):
     def __init__(self):
         super().__init__()
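One way to address the duplicated-branch comment, shown here as a sketch only and not as the PR's code: let the `affine` flag change what is actually passed to `native_layer_norm`, so a single call covers both the affine and affine-free cases.

```python
import torch


class NativeLayerNorm(torch.nn.Module):
    def __init__(self, affine=True):
        super().__init__()
        self.normalized_shape = [768]
        self.eps = 1e-6
        if affine:
            self.weight = torch.nn.Parameter(torch.ones(768))
            self.bias = torch.nn.Parameter(torch.zeros(768))
        else:
            # Affine-free: no learnable scale/shift is handed to the op.
            self.weight = None
            self.bias = None

    def forward(self, x):
        # native_layer_norm returns (output, mean, rstd); only the output is used.
        return torch.native_layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )[0]
```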
I think the bias is optional for QNN and can be kept as in the original design.
https://docs.qualcomm.com/doc/80-63442-10/topic/MasterOpDef.html#layernorm
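For context, an illustration rather than part of the PR: PyTorch already allows both weight and bias to be omitted from `native_layer_norm`, which is the affine-free form the backend then has to lower without a bias tensor.

```python
import torch

x = torch.randn(2, 4, 768)

# Affine-free call: weight and bias may both be None.
out, mean, rstd = torch.native_layer_norm(x, [768], None, None, 1e-6)

# Matches nn.LayerNorm with elementwise_affine=False.
ln = torch.nn.LayerNorm(768, eps=1e-6, elementwise_affine=False)
torch.testing.assert_close(out, ln(x))
```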