Skip to content

Commit a050b21

Browse files
Update functional.py
update 2D padded format as well
1 parent 9674738 commit a050b21

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

areal/utils/functional/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def apply_rejection_sampling(
426426
if config.lower is not None:
427427
token_oor = token_oor | (token_ratio < config.lower)
428428
loss_mask = loss_mask * (~token_oor).to(loss_mask.dtype)
429-
behave_imp_weight = behave_imp_weight * (~token_oor).to(
429+
behave_imp_weight = token_ratio * (~token_oor).to(
430430
behave_imp_weight.dtype
431431
)
432432

0 commit comments

Comments
 (0)