Commit d16b705
Add validation for attention type in the Attention module.
Use enum.auto() for AttentionType members. Add a ValueError check to ensure attn_type is either GLOBAL or LOCAL_SLIDING when processing attention masks.
PiperOrigin-RevId: 8713118091 parent bd0dabf commit d16b705
1 file changed
Lines changed: 7 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
56 | | - | |
57 | | - | |
| 56 | + | |
| 57 | + | |
58 | 58 | | |
59 | 59 | | |
60 | 60 | | |
| |||
267 | 267 | | |
268 | 268 | | |
269 | 269 | | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
270 | 275 | | |
271 | 276 | | |
272 | 277 | | |
| |||
0 commit comments