While using a policy network with a BatchNorm layer, I get the following error:
ModifyScopeVariableError: Cannot update variable "mean" in "/bn_init" because collection "batch_stats" is immutable.
(https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ModifyScopeVariableError)