feat(rs): add two-stage Geo-RS + Token-MIS/TIS mode to RejectionSamplingConfig #1218
morgan-heisler wants to merge 24 commits into inclusionAI:main
Code Review
This pull request introduces a two-stage rejection sampling mechanism, adding a sequence-level rejection stage (Geo-RS) followed by a token-level 'mask' or 'clamp' action. This is implemented via a new token_action configuration parameter in RejectionSamplingConfig, with corresponding logic added to the functional sampling utilities and comprehensive validation tests. The review feedback highlights a logic discrepancy where the 'mask' mode incorrectly maintains sequence-level weights instead of per-token ratios, which undermines the goal of token-level credit assignment. Additionally, a performance optimization was suggested to reuse previously calculated token ratios instead of recomputing exponentials.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
update 2D padded format as well
add per_token_ratio
Description
This PR implements the two-stage importance sampling mode proposed in the now-closed [PR #1084](#1084), which maintainer @garrett4wade referenced as something that "should have been enclosed in #1088" when closing that PR.
The core idea is a two-stage rollout correction pipeline:
- Stage 1 (Geo-RS): sequences whose geometric-mean ratio `exp(mean(log(π_prox/π_behave)))` exceeds `upper` are fully rejected. This is the existing `level='sequence', metric='ratio'` behaviour from #1088 (feat(api): add unified RejectionSamplingConfig for async training).
- Stage 2 (token level): `token_action='mask'` (Token-MIS) zeros out individual tokens where the token-level ratio still exceeds `upper`. `token_action='clamp'` (Token-TIS) clamps per-token weights to `[lower, upper]`.

Why the two stages are both necessary
Experimental evidence from PR #1084 shows:
The reason: pure Geo-RS without token-level weighting causes gradient-norm explosion because accepted sequences receive a uniform weight of `1.0`, with no token-level credit assignment. Conversely, pure Token-MIS without Geo-RS pre-filtering lets divergent sequences through, causing KL runaway. The combination bounds sequence-level divergence and provides a fine-grained token-level gradient signal.
This also achieves full feature parity with [VERL's five decoupled IS modes](https://verl.readthedocs.io/en/latest/algo/rollout_corr.html).
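The two stages described above can be sketched for a single sequence as follows. This is an illustration only, not the code in `apply_rejection_sampling()`: the function name `two_stage_weights` is hypothetical, plain Python lists stand in for tensors, and it assumes that tokens surviving the `mask` stage keep their per-token ratio as weight, in line with the review feedback.

```python
import math


def two_stage_weights(logp_prox, logp_behave, upper, lower=0.0, token_action=None):
    """Illustrative per-token weights for one sequence (hypothetical helper).

    logp_prox / logp_behave are per-token log-probabilities under the
    proximal and behaviour policies. A weight of 0.0 means the token is
    dropped from the loss.
    """
    log_ratios = [p - b for p, b in zip(logp_prox, logp_behave)]
    token_ratios = [math.exp(r) for r in log_ratios]

    # Stage 1 (Geo-RS): geometric-mean ratio = exp(mean(log ratio)).
    geo_mean = math.exp(sum(log_ratios) / len(log_ratios))
    if geo_mean > upper:
        return [0.0] * len(log_ratios)  # the whole sequence is rejected

    # Stage 2 only runs on sequences accepted by Stage 1.
    if token_action is None:
        return [1.0] * len(log_ratios)  # pure Geo-RS: uniform weight
    if token_action == "mask":
        # Token-MIS: zero out tokens whose ratio still exceeds `upper`.
        # Assumption: surviving tokens keep their own ratio as weight.
        return [0.0 if r > upper else r for r in token_ratios]
    if token_action == "clamp":
        # Token-TIS: clamp each per-token weight into [lower, upper].
        return [min(max(r, lower), upper) for r in token_ratios]
    raise ValueError(f"unknown token_action: {token_action!r}")


# Token ratios of roughly 3.0, 0.25 and 1.0; the geometric mean is below 2.0,
# so the sequence passes Stage 1 and only Stage 2 changes the weights.
log_ratios = [math.log(3.0), math.log(0.25), 0.0]
zeros = [0.0] * len(log_ratios)
masked = two_stage_weights(log_ratios, zeros, upper=2.0, lower=0.5, token_action="mask")
clamped = two_stage_weights(log_ratios, zeros, upper=2.0, lower=0.5, token_action="clamp")
```

With these inputs the first token is dropped under `mask` and capped at `upper` under `clamp`, while a sequence whose geometric mean exceeds `upper` gets zero weight everywhere regardless of `token_action`.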
Key Changes
areal/api/cli_args.py
Added `token_action: str | None` field to `RejectionSamplingConfig`.
Extended `__post_init__` validation to enforce:
- `token_action` requires `level='sequence'`
- `token_action` requires `metric='ratio'`
- `token_action` requires `action='mask'` (hard sequence rejection as Stage 1)

areal/utils/functional/functional.py
Added Stage 2 logic at the end of `apply_rejection_sampling()`, after the existing sequence-level filtering block. The new code runs only when `cfg.token_action is not None`.
This is compatible with both the 2D padded and 1D packed tensor formats already handled by `apply_rejection_sampling`.

tests/test_rejection_sampling.py
Added `TestTwoStageRejectionSampling` class with 11 test cases:
- `test_token_action_requires_sequence_level`: config validation
- `test_token_action_requires_ratio_metric`: config validation
- `test_token_action_requires_action_mask`: config validation
- `test_token_action_invalid_value`: config validation
- `test_valid_two_stage_mis_config`: valid config constructs without error
- `test_valid_two_stage_tis_config`: valid config constructs without error
- `test_two_stage_mis_rejects_divergent_sequences_2d`: Stage 1 functional
- `test_two_stage_mis_filters_high_token_ratio_within_accepted_seq`: Stage 2 Token-MIS
- `test_two_stage_tis_clamps_token_ratios_in_accepted_seq`: Stage 2 Token-TIS
- `test_stage1_rejection_dominates_stage2`: Stage 1 dominates even when Stage 2 would pass
- `test_none_token_action_is_pure_sequence_geo_rs`: `token_action=None` is identical to existing behaviour

Recommended Config
For off-policy async training where you want maximum stability:
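A minimal sketch of what such a Geo-RS + Token-MIS configuration might look like. The class `RejectionSamplingConfigSketch` is a hypothetical stand-in, not the AReaL class: its defaults and the `upper=2.0` value are placeholders rather than recommended settings, and its `__post_init__` only mirrors the validation rules listed under Key Changes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RejectionSamplingConfigSketch:
    """Hypothetical stand-in for RejectionSamplingConfig (illustration only)."""

    level: str = "token"        # placeholder default
    metric: str = "ratio"       # placeholder default
    action: str = "mask"        # placeholder default
    upper: float = 2.0          # placeholder bound, not a recommended value
    lower: Optional[float] = None
    token_action: Optional[str] = None  # None keeps single-stage behaviour

    def __post_init__(self):
        if self.token_action is None:
            return
        if self.token_action not in ("mask", "clamp"):
            raise ValueError("token_action must be 'mask', 'clamp' or None")
        if self.level != "sequence":
            raise ValueError("token_action requires level='sequence'")
        if self.metric != "ratio":
            raise ValueError("token_action requires metric='ratio'")
        if self.action != "mask":
            raise ValueError("token_action requires action='mask'")


# Stage 1: hard sequence rejection on the geometric-mean ratio.
# Stage 2: Token-MIS masking inside the accepted sequences.
mis_cfg = RejectionSamplingConfigSketch(
    level="sequence",
    metric="ratio",
    action="mask",
    token_action="mask",
    upper=2.0,
)
```

Constructing the sketch with `token_action='mask'` but `level='token'` raises `ValueError`, which is the kind of misconfiguration the validation rules are meant to catch early.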
For a softer variant (keeps all tokens but bounds their weights):
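A rough sketch of the softer Token-TIS variant, written as plain keyword arguments so it does not depend on the real class. The field names come from this description; the `lower=0.5` and `upper=2.0` bounds and the helper `clamp_weights` are illustrative assumptions.

```python
# Keyword arguments one might pass to RejectionSamplingConfig (values are placeholders).
tis_kwargs = dict(
    level="sequence",
    metric="ratio",
    action="mask",          # Stage 1 still rejects divergent sequences outright
    token_action="clamp",   # Stage 2 keeps every token but bounds its weight
    lower=0.5,
    upper=2.0,
)


def clamp_weights(token_ratios, lower, upper):
    """Hypothetical helper: bound each per-token ratio to [lower, upper]."""
    return [min(max(r, lower), upper) for r in token_ratios]


weights = clamp_weights([0.1, 1.0, 3.5], tis_kwargs["lower"], tis_kwargs["upper"])
# weights == [0.5, 1.0, 2.0]: no token is dropped, extreme ratios are bounded
```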
Backward Compatibility
This is a non-breaking additive change. The new `token_action` field defaults to `None`, which preserves all existing behaviour exactly.
Related
`RejectionSamplingConfig` introduced there)
Type of Change
Checklist