Skip to content

feat(rs): add two-stage Geo-RS + Token-MIS/TIS mode to RejectionSamplingConfig#1218

Draft
morgan-heisler wants to merge 24 commits intoinclusionAI:mainfrom
morgan-heisler:two-stage-sampling
Draft

feat(rs): add two-stage Geo-RS + Token-MIS/TIS mode to RejectionSamplingConfig#1218
morgan-heisler wants to merge 24 commits intoinclusionAI:mainfrom
morgan-heisler:two-stage-sampling

Conversation

@morgan-heisler
Copy link
Copy Markdown

Description

This PR implements the two-stage importance sampling mode proposed in the now-closed [PR #1084](#1084), which was referenced by
maintainer @garrett4wade as something that "should have been enclosed in #1088" when closing that PR.

The core idea is a two-stage rollout correction pipeline:

  • Stage 1 — Geometric Rejection Sampling (Geo-RS): Sequences whose geometric-mean importance ratio exp(mean(log(π_prox/π_behave))) exceeds upper are fully rejected. This is the existing level='sequence', metric='ratio' behaviour from feat(api): add unified RejectionSamplingConfig for async training #1088.
  • Stage 2 — Token-level correction (Token-MIS or Token-TIS): On the sequences that passed Stage 1, apply a per-token filter. token_action='mask' (Token-MIS) zeros out individual tokens where the token-level ratio still exceeds upper. token_action='clamp' (Token-TIS) clamps per-token weights to [lower, upper].

Why the two stages are both necessary

Experimental evidence from PR #1084 shows:

Method grad_norm approx_kl Stability
Token-MIS only ~0.5 (stable) Explodes Collapses ~250 steps
Geo-RS only Explodes (~15) Explodes Collapses quickly
Geo-RS + Token-MIS ~0.5 ~0 (stable) No collapse

The reason: pure Geo-RS without token-level weighting causes gradient-norm explosion because accepted sequences receive uniform weight 1.0 — no token-level credit assignment. And pure Token-MIS without Geo-RS pre-filtering lets divergent sequences through, causing KL runaway. The combination bounds sequence-level divergence and provides fine-grained token-level gradient signal.

This also achieves full feature parity with [VERL's five decoupled IS modes](https://verl.readthedocs.io/en/latest/algo/rollout_corr.html).

Key Changes

areal/api/cli_args.py

Added token_action: str | None field to RejectionSamplingConfig:

token_action: str | None = field(
    default=None,
    metadata={
        "help": (
            "Enables two-stage Geo-RS + Token-MIS/TIS mode. "
            "Only valid when level='sequence' and metric='ratio'. "
            "Stage 1 (Geo-RS): sequences whose geometric-mean ratio exceeds "
            "`upper` are fully rejected. "
            "Stage 2: on accepted sequences, apply per-token correction — "
            "'mask' (Token-MIS) or 'clamp' (Token-TIS). "
            "None disables Stage 2 (pure sequence-level Geo-RS)."
        ),
        "choices": ["mask", "clamp"],
    },
)

Extended __post_init__ validation to enforce:

  • token_action requires level='sequence'
  • token_action requires metric='ratio'
  • token_action requires action='mask' (hard sequence rejection as Stage 1)

areal/utils/functional/functional.py

Added Stage 2 logic at the end of apply_rejection_sampling(), after the existing sequence-level filtering block. The new code runs only when cfg.token_action is not None:

if cfg.token_action is not None and cfg.level == "sequence" and cfg.metric == "ratio":
    token_ratio = torch.exp(log_ratio)  # per-token π_prox / π_behave

    if cfg.token_action == "mask":
        # Token-MIS: zero tokens where per-token ratio > upper
        token_out_of_bounds = token_ratio > cfg.upper
        if cfg.lower is not None:
            token_out_of_bounds |= token_ratio < cfg.lower
        loss_mask = loss_mask * (~token_out_of_bounds).to(loss_mask.dtype)
        behave_imp_weight = behave_imp_weight * (~token_out_of_bounds).to(...)

    elif cfg.token_action == "clamp":
        # Token-TIS: clamp per-token ratio to [lower, upper]
        clamp_lower = cfg.lower if cfg.lower is not None else 0.0
        behave_imp_weight = token_ratio.clamp(min=clamp_lower, max=cfg.upper)

This is compatible with both the 2D padded and 1D packed tensor formats already handled by apply_rejection_sampling.

tests/test_rejection_sampling.py

Added TestTwoStageRejectionSampling class with 9 test cases:

  1. test_token_action_requires_sequence_level — config validation
  2. test_token_action_requires_ratio_metric — config validation
  3. test_token_action_requires_action_mask — config validation
  4. test_token_action_invalid_value — config validation
  5. test_valid_two_stage_mis_config — valid config constructs without error
  6. test_valid_two_stage_tis_config — valid config constructs without error
  7. test_two_stage_mis_rejects_divergent_sequences_2d — Stage 1 functional
  8. test_two_stage_mis_filters_high_token_ratio_within_accepted_seq — Stage 2 Token-MIS
  9. test_two_stage_tis_clamps_token_ratios_in_accepted_seq — Stage 2 Token-TIS
  10. test_stage1_rejection_dominates_stage2 — Stage 1 dominates even when Stage 2 would pass
  11. test_none_token_action_is_pure_sequence_geo_rstoken_action=None is identical to existing behaviour

Recommended Config

For off-policy async training where you want maximum stability:

actor:
  rejection_sampling:
    level: sequence
    action: mask
    metric: ratio
    agg: mean
    upper: 2.0
    token_action: mask   # Geo-RS + Token-MIS (recommended)

For a softer variant (keeps all tokens but bounds their weights):

actor:
  rejection_sampling:
    level: sequence
    action: mask
    metric: ratio
    agg: mean
    upper: 2.0
    lower: 0.5
    token_action: clamp  # Geo-RS + Token-TIS

Backward Compatibility

This is a non-breaking additive change. The new token_action field defaults to None, which preserves all existing behaviour exactly.

Related

Type of Change

  • New feature (non-breaking change that adds functionality)
  • Test coverage improvement

Checklist

  • I have read CONTRIBUTING.md
  • Pre-commit hooks pass
  • All existing unit tests pass
  • New tests added for new functionality (11 test cases)
  • No breaking changes (token_action defaults to None)
  • Branch is up to date with main

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a two-stage rejection sampling mechanism, adding a sequence-level rejection stage (Geo-RS) followed by a token-level 'mask' or 'clamp' action. This is implemented via a new token_action configuration parameter in RejectionSamplingConfig, with corresponding logic added to the functional sampling utilities and comprehensive validation tests. The review feedback highlights a logic discrepancy where the 'mask' mode incorrectly maintains sequence-level weights instead of per-token ratios, which undermines the goal of token-level credit assignment. Additionally, a performance optimization was suggested to reuse previously calculated token ratios instead of recomputing exponentials.

Comment thread areal/utils/functional/functional.py Outdated
Comment thread areal/utils/functional/functional.py Outdated
morgan-heisler and others added 3 commits April 20, 2026 15:45
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
update 2D padded format as well
add per_token_ratio
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant