diff --git a/evojax/obs_norm.py b/evojax/obs_norm.py index 0078e0a2..13ca9a6a 100644 --- a/evojax/obs_norm.py +++ b/evojax/obs_norm.py @@ -46,6 +46,12 @@ def update_obs_params(obs_buffer: jnp.ndarray, obs_steps = obs_params[0] running_mean, running_var = jnp.split(obs_params[1:], 2) + + # reshape obs_params to support multi-dim observations + obs_shape = obs_buffer.shape[2:] # obs_buffer shape is [n_obs, pop_size, *obs_size] + running_mean = running_mean.reshape(obs_shape) + running_var = running_var.reshape(obs_shape) + if obs_mask.ndim != obs_buffer.ndim: obs_mask = obs_mask.reshape( obs_mask.shape + (1,) * (obs_buffer.ndim - obs_mask.ndim)) @@ -61,7 +67,7 @@ def update_obs_params(obs_buffer: jnp.ndarray, var_diff = jnp.sum(input_to_new_mean * input_to_old_mean, axis=(0, 1)) new_var = running_var + var_diff - return jnp.concatenate([jnp.ones(1) * total_steps, new_mean, new_var]) + return jnp.concatenate([jnp.ones(1) * total_steps, new_mean.flatten(), new_var.flatten()]) class ObsNormalizer(object):