diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py
index 7cd91e071..ade956cd7 100644
--- a/alf/trainers/policy_trainer.py
+++ b/alf/trainers/policy_trainer.py
@@ -592,8 +592,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1):
         # Make sure the BN statistics of different processes are synced
         # https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm
         # This conversion needs to be performed before wrapping modules with DDP.
-        self._algorithm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
-            self._algorithm)
+        self._algorithm = common.convert_sync_batchnorm(self._algorithm)
 
         # Create a thread env to expose subprocess gin/alf configurations
         # which otherwise will be marked as "inoperative". Only created when
diff --git a/alf/utils/common.py b/alf/utils/common.py
index bdf7840e5..6b23ea027 100644
--- a/alf/utils/common.py
+++ b/alf/utils/common.py
@@ -1714,3 +1714,12 @@ def get_unused_port(start, end=65536, n=1):
     if process_locks:
         for process_lock in process_locks:
             process_lock.release()
+
+
+try:
+    import timm
+    # timm models may contain BatchNormAct2d, which is not supported by torch.nn.SyncBatchNorm.
+    # Need to use timm's convert_sync_batchnorm (AttributeError: pre-0.9 timm has no timm.layers).
+    convert_sync_batchnorm = timm.layers.convert_sync_batchnorm
+except (ImportError, AttributeError):
+    convert_sync_batchnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm