diff --git a/muzero/self_play/utils.py b/muzero/self_play/utils.py index ec49217..52045b9 100644 --- a/muzero/self_play/utils.py +++ b/muzero/self_play/utils.py @@ -53,7 +53,7 @@ def value(self) -> Optional[float]: def softmax_sample(visit_counts, actions, t): - counts_exp = np.exp(visit_counts) * (1 / t) + counts_exp = np.exp(visit_counts) ** (1 / t) probs = counts_exp / np.sum(counts_exp, axis=0) action_idx = np.random.choice(len(actions), p=probs) return actions[action_idx]