Skip to content

Commit d16b705

Browse files
MaxGuigon authored and The gemma Authors committed
Add validation for attention type in the Attention module.
Use enum.auto() for AttentionType members. Add a ValueError check to ensure attn_type is either GLOBAL or LOCAL_SLIDING when processing attention masks. PiperOrigin-RevId: 871311809
1 parent bd0dabf commit d16b705

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

gemma/gm/nn/_modules.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,8 +53,8 @@ def create_sliding_mask(
5353

5454

5555
class AttentionType(enum.Enum):
56-
GLOBAL = 1
57-
LOCAL_SLIDING = 2
56+
GLOBAL = enum.auto()
57+
LOCAL_SLIDING = enum.auto()
5858

5959

6060
class Embedder(nn.Module):
@@ -267,6 +267,11 @@ def __call__(
267267
)
268268
# [batch_size, seq_len, cache_size]
269269
attn_mask *= sliding_mask
270+
elif self.attn_type != AttentionType.GLOBAL:
271+
raise ValueError(
272+
'Attn_type must be either AttentionType.GLOBAL or'
273+
f' AttentionType.GLOBAL not {self.attn_type}'
274+
)
270275

271276
# [batch_size, seq_len, num_heads, cache_size]
272277
padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)

0 commit comments

Comments
 (0)