Hi @williamSYSU,
Thanks for the work.
I found a potential bug in your SeqGAN rollout code.
Consider the following snippet from the function get_reward in utils/rollout.py:
rewards = torch.zeros([rollout_num * self.max_seq_len, batch_size]).float()
......
rewards = torch.mean(rewards.view(batch_size, self.max_seq_len, rollout_num), dim=-1)
The reward tensor is laid out as [rollout_num, max_seq_len, batch_size], but it is viewed as [batch_size, max_seq_len, rollout_num] and then averaged over the last dimension, which is expected to be rollout_num. However, view only reinterprets the underlying memory in row-major order and does not move any elements. After the view, the last dimension no longer indexes the rollouts, so the mean is computed over the wrong elements.
To correct this error, I think there needs to be a transpose operation before view.
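A small sketch of the problem and one possible fix, written with NumPy so it runs without PyTorch. It assumes, following the layout described above, that row i * max_seq_len + t of rewards holds the reward for rollout i at step t; the helper names buggy_mean and fixed_mean are hypothetical. NumPy's reshape on a C-contiguous array uses the same row-major ordering as torch.Tensor.view, so it stands in for the PyTorch call here. In PyTorch, the equivalent fix would be something like rewards.view(rollout_num, self.max_seq_len, batch_size).permute(2, 1, 0).mean(dim=-1), untested against the repository.

```python
import numpy as np

def buggy_mean(rewards, rollout_num, max_seq_len, batch_size):
    # Mirrors the original code: reinterpret the flat buffer as
    # [batch_size, max_seq_len, rollout_num] without moving any elements.
    return rewards.reshape(batch_size, max_seq_len, rollout_num).mean(axis=-1)

def fixed_mean(rewards, rollout_num, max_seq_len, batch_size):
    # First recover the real layout [rollout_num, max_seq_len, batch_size] ...
    r = rewards.reshape(rollout_num, max_seq_len, batch_size)
    # ... then reorder the axes to [batch_size, max_seq_len, rollout_num]
    # and average over the rollouts.
    return r.transpose(2, 1, 0).mean(axis=-1)

rollout_num, max_seq_len, batch_size = 3, 2, 4

# Encode (sample b, step t, rollout i) into each reward so mix-ups are visible.
# Assumed layout: row index = i * max_seq_len + t, column index = b.
rewards = np.zeros((rollout_num * max_seq_len, batch_size))
for i in range(rollout_num):
    for t in range(max_seq_len):
        for b in range(batch_size):
            rewards[i * max_seq_len + t, b] = 100 * b + 10 * t + i

# Averaging over rollouts i = 0, 1, 2 should give 100 * b + 10 * t + 1.
expected = np.array(
    [[100 * b + 10 * t + 1 for t in range(max_seq_len)] for b in range(batch_size)],
    dtype=float,
)

buggy = buggy_mean(rewards, rollout_num, max_seq_len, batch_size)
fixed = fixed_mean(rewards, rollout_num, max_seq_len, batch_size)
print(np.allclose(buggy, expected))  # False
print(np.allclose(fixed, expected))  # True
```

For example, buggy[0, 0] averages the rewards of samples 0, 1 and 2 at step 0 of rollout 0, rather than the three rollouts of sample 0 at step 0.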
Looking forward to your reply.