From e3c1bb79c576fd813a7ffde274f2a1eeb29e0598 Mon Sep 17 00:00:00 2001 From: GhostC <1276537536@qq.com> Date: Mon, 22 Dec 2025 21:00:28 +0800 Subject: [PATCH] Fix PPO's compute_reward() (#996) Signed-off-by Zehao Chen --- applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py index 22cba6be0..690639470 100644 --- a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py @@ -184,7 +184,7 @@ def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs) rewards = kl_divergence_estimate start = prompts.shape[1] - 1 - ends = start + action_mask[:, start:].sum(1) + 1 + ends = start + action_mask[:, start:].sum(1) reward_clip = torch.clamp(reward_score, -self.clip_reward_value, self.clip_reward_value) batch_size = log_probs.shape[0] @@ -212,7 +212,7 @@ def train_rlhf(self, inputs): old_rewards = self.compute_rewards(prompts, log_probs, ref_log_probs, reward_score, action_mask) - ends = start + action_mask[:, start:].sum(1) + 1 + ends = start + action_mask[:, start:].sum(1) # we need to zero out the reward and value after the end of the conversation # otherwise the advantage/return will be wrong for i in range(old_rewards.shape[0]):