diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
index 22cba6be0..690639470 100644
--- a/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
+++ b/applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
@@ -184,7 +184,7 @@ def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
         kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
         rewards = kl_divergence_estimate
         start = prompts.shape[1] - 1
-        ends = start + action_mask[:, start:].sum(1) + 1
+        ends = start + action_mask[:, start:].sum(1)
         reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                   self.clip_reward_value)
         batch_size = log_probs.shape[0]
@@ -212,7 +212,7 @@ def train_rlhf(self, inputs):
         old_rewards = self.compute_rewards(prompts, log_probs,
                                            ref_log_probs, reward_score,
                                            action_mask)
-        ends = start + action_mask[:, start:].sum(1) + 1
+        ends = start + action_mask[:, start:].sum(1)
         # we need to zero out the reward and value after the end of the conversation
         # otherwise the advantage/return will be wrong
         for i in range(old_rewards.shape[0]):
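
A minimal sketch (not part of the patch) of the index arithmetic the fix touches. It assumes toy values: a 4-token prompt followed by a 3-token answer and 2 pad tokens, with action_mask already shifted by one position as in DeepSpeed-Chat; the tensor contents below are made up purely for illustration.

import torch

# Toy setup: prompt of 4 tokens, answer of 3 tokens, 2 trailing pads.
# action_mask is the shifted attention mask, shape [batch, seq_len - 1].
prompt_len = 4
action_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])
rewards = torch.zeros(action_mask.shape, dtype=torch.float)

start = prompt_len - 1                        # 3: first answer position in shifted coordinates
ends = start + action_mask[:, start:].sum(1)  # 3 + 3 = 6: one past the last answer token

# rewards[0, start:ends[0]] covers exactly the answer positions (indices 3, 4, 5),
# so the reward added below lands on the last generated token (index 5).
# With the old "+ 1", ends would be 7 and the reward would be added at index 6,
# which corresponds to a pad token rather than the last generated token, and the
# zero-out loop in train_rlhf would leave that pad position un-zeroed.
rewards[0, start:ends[0]][-1] += 1.0
print(rewards)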