Skip to content

Conversation

@LeonEricsson
Copy link

@LeonEricsson LeonEricsson commented Feb 10, 2026

Motivation

Extend rejection sampling (RS) to support sequence-level mismatch-based rejection rules derived from tokenwise policy divergence proxies. SLIME already supports:

  • Importance sampling (IS): multiply the loss by an IS weight (token/sequence/geometric).
  • Rejection sampling (RS): mask tokens/sequences when IS weights fall outside configured bounds.

These RS mechanisms primarily control weight magnitude (i.e., they target high-variance outliers from extreme importance weights). In long-horizon settings, however, they may not be sufficient to detect harmful off-policy samples: per-token mismatches can cancel when aggregated, allowing problematic sequences to slip through.

Solution

Add trust-region RS criteria that reject entire sequences based on token-level KL mismatch aggregated over the trajectory:

  • Max mismatch: reject if any token violates the trust region (captures localized catastrophic divergence).
  • Mean mismatch: reject if the average mismatch over tokens is too large (captures globally off-policy sequences).

Following Trust Region Masking for Long-Horizon LLM Reinforcement Learning, this PR uses:

  • k2 for the max criterion (symmetric, detects divergence in both directions), and
  • k3 for the mean criterion (unbiased for the forward-KL target under rollout sampling and strictly non-negative, reducing cancellation effects).

Both criteria use an upper bound only (rs_tr_threshold). The max criterion replaces the existing veto mechanism as a more principled outlier detector.

Implementation

This PR extends the existing RS path with a trust-region option supporting max and mean criteria. The implementation closely follows the paper, but intentionally does not introduce a general “rule engine” for trust-region / RS composition.

A more general alternative would provide a flexible rule framework that can combine:

  • token statistics (e.g., k1/k2/k3), and
  • reductions (e.g., max/mean/sum),
    alongside existing IS-weight-based rejection.

Alternative implementation (rule-based RS)

As a potential follow-up, a configurable rule framework could look like this:

--rs-rule (repeatable and composable)

Format:
--rs-rule "name=;scope=<sequence|token>;stat=<k1|k2|k3>;reduce=<...>;low=<float?>;high=<float?>"

Required keys:

  • name: unique identifier
  • scope: must support sequence (token scope optional/future; if unsupported, error)
  • stat: token statistic; must support k1, k2, k3
  • reduce: token-to-scalar reduction; must support:
    • max
    • mean
    • sum
    • identity (required only if scope=token; optional otherwise)

In such a framework, the legacy RS modes naturally map to k1 with sum (sequence) or mean (geometric) reduction. This approach is inspired by this verl PR and docs.


The max (k2) and mean (k3) trust-region criteria implemented in this PR are likely the most important variants in practice. That said, a more flexible rule framework would make it easier to experiment with additional criteria and compositions as we iterate. I’m happy to forgo this PR and focus on a said flexible design instead. I opened this narrower change as concrete baseline to kick of a discussion.

TODO

  • Document trust region masking in train_infer_mismatch_helper/README.md

@LeonEricsson LeonEricsson changed the title Feature/rs trust region gate [Feature] Trust region rejection sampling Feb 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant