
Add PSPO trust region method as alternative to clipping in GRPOTrainer#4548

Open
MCDwyer wants to merge 10 commits into huggingface:main from MCDwyer:add-pspo-to-grpo

Conversation

@MCDwyer
Copy link

@MCDwyer MCDwyer commented Nov 19, 2025

What does this PR do?

Adds PSPO (Probability Smoothing Policy Optimisation) as an alternative trust-region method to GRPOTrainer. PSPO smooths probabilities toward the behaviour policy instead of using ratio clipping.

Paper: https://arxiv.org/abs/2509.21282
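The difference between the two trust-region methods can be sketched numerically. This is a minimal plain-Python illustration of the idea (not the PR's torch implementation, which applies the same transform per token): clipping hard-limits the importance ratio, while PSPO interpolates it toward 1 (i.e. toward the behaviour policy), so large ratios are damped rather than flattened.

```python
# Illustrative sketch: PPO-style ratio clipping vs. PSPO probability smoothing.
# Function names and example values here are made up for illustration.

def clip_ratio(ratio, eps_low=0.2, eps_high=0.2):
    """Hard-clip the importance ratio into [1 - eps_low, 1 + eps_high]."""
    return max(1 - eps_low, min(ratio, 1 + eps_high))

def pspo_ratio(ratio, alpha=0.1):
    """Interpolate the ratio toward 1 (the behaviour policy) with weight alpha."""
    return (1 - alpha) * ratio + alpha

for r in (0.5, 1.0, 3.0):
    print(f"ratio={r}: clip -> {clip_ratio(r)}, pspo -> {pspo_ratio(r)}")
```

Note that PSPO leaves the ratio at 1.0 unchanged and, unlike clipping, keeps a nonzero gradient with respect to the ratio everywhere.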

Changes:

  • Added trust_region_method parameter to GRPOConfig (default: "clip")
  • Added smoothing_alpha parameter for PSPO (default: 0.1)
  • Implemented PSPO smoothing in GRPOTrainer
  • Maintains backward compatibility

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@kashif

@kashif kashif self-assigned this Nov 19, 2025
@qgallouedec
Copy link
Member

qgallouedec commented Nov 21, 2025

thanks! can you apply the style (make precommit) and add a short section to the paper index documentation page?

and also, it would be nice to have a test case for this.

@MCDwyer
Copy link
Author

MCDwyer commented Nov 24, 2025

I've applied the style, added the paper documentation in docs/source/gr_pspo.md, and added a test in tests/test_grpo_trainer.py. Please let me know if there is anything else I need to do.

@qgallouedec
Copy link
Member

Oh, sorry, maybe I wasn't clear. You need to add a section to this part of the documentation:

https://github.com/huggingface/trl/blob/main/docs/source/paper_index.md

@MCDwyer
Copy link
Author

MCDwyer commented Nov 25, 2025

Sorry, I had misunderstood, thank you for clarifying. I've moved the documentation to a section in paper_index.md.

@MCDwyer
Copy link
Author

MCDwyer commented Jan 30, 2026

Hey, just checking in, is there anything else I need to do?

@casinca
Copy link
Contributor

casinca commented Feb 10, 2026

Hi, I found PSPO in a backlog of papers, and while checking whether it was already implemented I found your PR, nice!

When this gets merged I wonder if we could simplify/merge the logic with SAPO, since both are soft trust region methods.

@MCDwyer
Copy link
Author

MCDwyer commented Feb 16, 2026

Hi, yes. One way to do this could be to make SAPO a trust-region method rather than a loss type, and have it use the GRPO loss when SAPO is picked as the trust-region method. That would disentangle the trust-region methods from the loss types while keeping the SAPO implementation essentially the same.

I did this for a comparison, as I was using an older trl version which didn't have the SAPO implementation yet, and I changed _compute_loss to the following:

        coef_1 = torch.exp(log_importance_weights)

        if self.trust_region_method == "clip":
            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

            # Two-sided clipping
            if self.args.delta is not None:
                coef_1 = torch.clamp(coef_1, max=self.args.delta)

            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

        elif self.trust_region_method == "pspo":
            # Smooth the ratio toward 1; equivalent to smoothing the probability
            # toward the behaviour policy.
            coef_2 = (1.0 - self.smoothing_alpha) * coef_1 + self.smoothing_alpha
            per_token_loss1 = None  # not used for PSPO
            per_token_loss2 = None
            # Always use the smoothed ratio for the loss (no min with the raw ratio).
            per_token_loss = -(coef_2 * advantages.unsqueeze(1))

        elif self.trust_region_method == "sapo":
            # Note: copied from the newer trl repo, but moved to where the coef
            # methods are selected. If using this, the loss type should be "grpo",
            # as they do.
            per_token_loss = torch.empty_like(coef_1)
            positive_advantages_mask = (advantages.unsqueeze(1) > 0).expand_as(coef_1)
            per_token_loss[positive_advantages_mask] = self.get_sapo_token_loss(
                coef_1[positive_advantages_mask], self.args.tau_pos
            )
            per_token_loss[~positive_advantages_mask] = self.get_sapo_token_loss(
                coef_1[~positive_advantages_mask], self.args.tau_neg
            )
            per_token_loss = -(per_token_loss * advantages.unsqueeze(1))

        else:
            raise ValueError(
                f"Unknown trust_region_method (expected 'clip', 'pspo', or 'sapo'): {self.trust_region_method}"
            )

I could add this change into this PR if that would be useful?
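For reference, usage of the proposed configuration might look like the following. This is a hypothetical sketch assuming the parameter names proposed in this PR (trust_region_method, smoothing_alpha); it is not a merged TRL API.

```python
# Hypothetical config sketch; trust_region_method and smoothing_alpha are the
# names proposed in this PR, not (yet) part of the released trl library.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="out",
    trust_region_method="pspo",  # default would be "clip" (current behaviour)
    smoothing_alpha=0.1,         # interpolation weight toward the behaviour policy
)
```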

@casinca
Copy link
Contributor

casinca commented Feb 16, 2026

When SAPO was implemented, it was the first soft trust-region method in TRL GRPO, if I'm not mistaken, hence the choice to keep it "per loss", I presume. But if more variants are added that differ solely from a trust-region standpoint (I've also seen #5027), I agree it might be good to re-evaluate at some point.

In any case, I don't want to give you extra work without approval. Ultimately, the decision is up to @kashif since he's assigned.

It's not a big deal anyway, I can always open a PR later, once this is merged.

Btw concerning your snippet, I simplified SAPO in this PR #4956 since.
