From 659654f95ddf40f477459c48c2305a83dfde14f1 Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Mon, 22 Jul 2024 21:57:12 -0700 Subject: [PATCH] Fix convert_sync_batchnorm for timm timm (PyTorch Image Models) models may contain BatchNormAct2d, which is not supported by torch.nn.SyncBatchNorm. Need to use timm's convert_sync_batchnorm if timm is used. --- alf/trainers/policy_trainer.py | 3 +-- alf/utils/common.py | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/alf/trainers/policy_trainer.py b/alf/trainers/policy_trainer.py index 7cd91e071..ade956cd7 100644 --- a/alf/trainers/policy_trainer.py +++ b/alf/trainers/policy_trainer.py @@ -592,8 +592,7 @@ def __init__(self, config: TrainerConfig, ddp_rank: int = -1): # Make sure the BN statistics of different processes are synced # https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm # This conversion needs to be performed before wrapping modules with DDP. - self._algorithm = torch.nn.SyncBatchNorm.convert_sync_batchnorm( - self._algorithm) + self._algorithm = common.convert_sync_batchnorm(self._algorithm) # Create a thread env to expose subprocess gin/alf configurations # which otherwise will be marked as "inoperative". Only created when diff --git a/alf/utils/common.py b/alf/utils/common.py index bdf7840e5..6b23ea027 100644 --- a/alf/utils/common.py +++ b/alf/utils/common.py @@ -1714,3 +1714,12 @@ def get_unused_port(start, end=65536, n=1): if process_locks: for process_lock in process_locks: process_lock.release() + + +try: + # timm models may contain BatchNormAct2d, which is not supported by torch.nn.SyncBatchNorm. + # Need to use timm's convert_sync_batchnorm. Import it by name: if timm has no timm.layers + # or no such function, this raises ImportError (not AttributeError) and we fall back to torch. + from timm.layers import convert_sync_batchnorm +except ImportError: + convert_sync_batchnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm