Description
We were trying to finetune a MatFormer checkpoint (MatFormer-OLMo-180M, linked in the README).
We used the following command to call the training script:
python train.py ../configs/pile-tiny.yaml \
--matformer_factor=8 \
--model.d_model=512 \
--model.n_heads=16 \
--model.n_layers=8 \
--model.max_sequence_length=2048 \
--device_train_microbatch_size=8 \
--global_train_batch_size=128 \
--max_duration=75000 \
--optimizer.learning_rate=1.0e-3 \
--console_log_interval=10 \
--load_path="/raid/ganesh/namitha/Skill_localization_experiment/ckpt_paths/MatFormer-OLMo-180M" \
--run_name="matformer-olmo-180M-finetune"
where the folder passed to load_path was obtained by downloading from the link in the README for MatFormer-OLMo-180M.
However, running this gives us the following error:
[2024-04-18 09:09:04] CRITICAL [root, rank=0] Uncaught ValueError: Must flatten tensors on the same device but got both cuda:0 and meta
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│ │
│ 226 │ │ raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]") │
│ 227 │ print([clean_opt(s) for s in args_list]) │
│ 228 │ cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list]) │
│ ❱ 229 │ main(cfg) │
│ 230 │
│ │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main │
│ │
│ 105 │ log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│ 106 │ torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1) │
│ 107 │ # Wrap the model in FSDP. │
│ ❱ 108 │ fsdp_model = FSDP( │
│ 109 │ │ olmo_model, │
│ 110 │ │ sharding_strategy=cfg.fsdp.sharding_strategy, │
│ 111 │ │ mixed_precision=MixedPrecision( # equivalent to MosaicML's "PURE" │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 474 │ │ │ │ # process groups. │
│ 475 │ │ │ │ root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) │
│ 476 │ │ │ │
│ ❱ 477 │ │ │ _auto_wrap( │
│ 478 │ │ │ │ module, │
│ 479 │ │ │ │ auto_wrap_policy, │
│ 480 │ │ │ │ self._ignored_modules, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 98 │ │ ) │
│ 99 │ │ recursive_wrap_kwargs["auto_wrap_policy"] = policy │
│ 100 │ │ _warn_on_overridden_mixed_precision(overridden_module_classes) │
│ ❱ 101 │ _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] │
│ 102 │
│ 103 │
│ 104 def _check_nested_wrapping(root_module: nn.Module): │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 558 │ │ │ module=module, recurse=False, nonwrapped_numel=remainder │
│ 559 │ │ ): │
│ 560 │ │ │ # Leaf node or final wrapping of the remainder both happen here. │
│ ❱ 561 │ │ │ return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel │
│ 562 │ │ else: │
│ 563 │ │ │ return module, total_wrapped_numel │
│ 564 │ return module, 0 │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 487 │ │ overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type] │
│ 488 │ │ return wrapper_cls(module, **overrides) │
│ 489 │ │
│ ❱ 490 │ return wrapper_cls(module, **kwargs) │
│ 491 │
│ 492 │
│ 493 def _recursive_wrap( │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 500 │ │ _init_buffer_state(self, module) │
│ 501 │ │ # extension needs to be set before `_init_param_handle_from_module()` │
│ 502 │ │ _init_extension(self, device_mesh) │
│ ❱ 503 │ │ _init_param_handle_from_module( │
│ 504 │ │ │ self, │
│ 505 │ │ │ module, │
│ 506 │ │ │ device_id, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 587 │ │ │ _sync_module_params_and_buffers( │
│ 588 │ │ │ │ fully_sharded_module, managed_params, state._inter_node_pg │
│ 589 │ │ │ ) │
│ ❱ 590 │ _init_param_handle_from_params(state, managed_params, fully_sharded_module) │
│ 591 │ return state │
│ 592 │
│ 593 │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 599 ): │
│ 600 │ if len(params) == 0: │
│ 601 │ │ return │
│ ❱ 602 │ handle = FlatParamHandle( │
│ 603 │ │ params, │
│ 604 │ │ fully_sharded_module, │
│ 605 │ │ state.compute_device, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 570 │ │ │ else 0 │
│ 571 │ │ ) │
│ 572 │ │ self._fsdp_extension = fsdp_extension │
│ ❱ 573 │ │ self._init_flat_param_and_metadata( │
│ 574 │ │ │ params, fully_sharded_module, self._aligned_numel, use_orig_params # type: i │
│ 575 │ │ ) │
│ 576 │ │ self._use_unsharded_views(as_params=False) │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 620 │ │ │ dtype, │
│ 621 │ │ │ flat_param_requires_grad, │
│ 622 │ │ │ device, │
│ ❱ 623 │ │ ) = self._validate_tensors_to_flatten(params) │
│ 624 │ │ params_set = set(params) │
│ 625 │ │ # For alignment padding, only `numels` gets strictly non-`None` │
│ 626 │ │ # elements, and all other lists get `None` elements for padding. │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 773 │ │ │ │ │ "`use_orig_params=False`" │
│ 774 │ │ │ │ ) │
│ 775 │ │ │ if device is not None and tensor.device != device: │
│ ❱ 776 │ │ │ │ raise ValueError( │
│ 777 │ │ │ │ │ "Must flatten tensors on the same device but got both " │
│ 778 │ │ │ │ │ f"{device} and {tensor.device}" │
│ 779 │ │ │ │ ) │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Must flatten tensors on the same device but got both cuda:0 and meta
We are unable to resolve this issue.
We tried adding the following line to torch/distributed/fsdp/_init_utils.py:
tensor.to("cuda:0")
But this operation gives another error, as follows:
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 753 │ │ device: Optional[torch.device] = None │
│ 754 │ │ # For `use_orig_params=True`, permit non-uniform `requires_grad` │
│ 755 │ │ for tensor in tensors: │
│ ❱ 756 │ │ │ tensor.to("cuda:0") │
│ 757 │ │ │ if isinstance(tensor, FlatParameter): │
│ 758 │ │ │ │ raise ValueError("Cannot flatten a `FlatParameter`") │
│ 759 │ │ │ if dtype is None and not tensor.is_floating_point(): │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
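Presumably this second error occurs because meta tensors carry no data at all, so there is nothing to copy out of them (and Tensor.to is not in-place, so its return value is discarded anyway). From reading the FSDP source, the intended way to handle parameters that are still on the meta device seems to be to materialize them before flattening, for example by passing a param_init_fn to FSDP so that each meta-device module is allocated on the GPU and re-initialized there. Below is a minimal sketch of what we had in mind for the FSDP call in scripts/train.py; it is only an assumption on our part, and the re-initialization would have to match however OLMo actually initializes its modules:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def param_init_fn(module: torch.nn.Module):
    # Allocate uninitialized storage for this module's parameters/buffers on the GPU,
    # replacing the meta-device placeholders, then re-initialize where possible.
    module.to_empty(device=torch.device("cuda:0"), recurse=False)
    if callable(getattr(module, "reset_parameters", None)):
        module.reset_parameters()

fsdp_model = FSDP(
    olmo_model,
    param_init_fn=param_init_fn,  # materialize meta-device modules before flattening
    device_id=torch.device("cuda:0"),
    # ... remaining kwargs (sharding_strategy, mixed_precision, ...) as already in train.py
)

The guard around reset_parameters would also sidestep the LayerNorm error described further below, since modules without such a method would simply be left with empty storage (assuming the checkpoint from load_path is restored afterwards and overwrites it).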
We have made other changes to pile-tiny.yaml, scripts/train.py, and scripts/util.py to make them compatible for training.
I am attaching a zip of those files here:
changes.zip
Apart from this, we were facing another issue:
[2024-04-18 09:30:56] CRITICAL [root, rank=0] Uncaught AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│ │
│ 226 │ │ raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]") │
│ 227 │ print([clean_opt(s) for s in args_list]) │
│ 228 │ cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list]) │
│ ❱ 229 │ main(cfg) │
│ 230 │
│ │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main │
│ │
│ 105 │ log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│ 106 │ torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1) │
│ 107 │ # Wrap the model in FSDP. │
│ ❱ 108 │ fsdp_model = FSDP( │
│ 109 │ │ olmo_model, │
│ 110 │ │ sharding_strategy=cfg.fsdp.sharding_strategy, │
│ 111 │ │ mixed_precision=MixedPrecision( # equivalent to MosaicML's "PURE" │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 474 │ │ │ │ # process groups. │
│ 475 │ │ │ │ root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) │
│ 476 │ │ │ │
│ ❱ 477 │ │ │ _auto_wrap( │
│ 478 │ │ │ │ module, │
│ 479 │ │ │ │ auto_wrap_policy, │
│ 480 │ │ │ │ self._ignored_modules, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 98 │ │ ) │
│ 99 │ │ recursive_wrap_kwargs["auto_wrap_policy"] = policy │
│ 100 │ │ _warn_on_overridden_mixed_precision(overridden_module_classes) │
│ ❱ 101 │ _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] │
│ 102 │
│ 103 │
│ 104 def _check_nested_wrapping(root_module: nn.Module): │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 540 │ │ for name, child in module.named_children(): │
│ 541 │ │ │ if child in ignored_modules: │
│ 542 │ │ │ │ continue │
│ ❱ 543 │ │ │ wrapped_child, num_wrapped_params = _recursive_wrap( │
│ 544 │ │ │ │ module=child, │
│ 545 │ │ │ │ auto_wrap_policy=auto_wrap_policy, │
│ 546 │ │ │ │ wrapper_cls=wrapper_cls, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 558 │ │ │ module=module, recurse=False, nonwrapped_numel=remainder │
│ 559 │ │ ): │
│ 560 │ │ │ # Leaf node or final wrapping of the remainder both happen here. │
│ ❱ 561 │ │ │ return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel │
│ 562 │ │ else: │
│ 563 │ │ │ return module, total_wrapped_numel │
│ 564 │ return module, 0 │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 487 │ │ overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type] │
│ 488 │ │ return wrapper_cls(module, **overrides) │
│ 489 │ │
│ ❱ 490 │ return wrapper_cls(module, **kwargs) │
│ 491 │
│ 492 │
│ 493 def _recursive_wrap( │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 500 │ │ _init_buffer_state(self, module) │
│ 501 │ │ # extension needs to be set before `_init_param_handle_from_module()` │
│ 502 │ │ _init_extension(self, device_mesh) │
│ ❱ 503 │ │ _init_param_handle_from_module( │
│ 504 │ │ │ self, │
│ 505 │ │ │ module, │
│ 506 │ │ │ device_id, │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 549 │ │ │ fully_sharded_module, param_init_fn, state._ignored_modules │
│ 550 │ │ ) │
│ 551 │ elif is_meta_module: │
│ ❱ 552 │ │ _materialize_meta_module( │
│ 553 │ │ │ fully_sharded_module, device_id, state._ignored_modules │
│ 554 │ │ ) │
│ 555 │ elif is_torchdistX_deferred_init: │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 881 │ │ │ f"device with error {str(e)}. Please ensure that your module of" │
│ 882 │ │ │ f"type {type(module)} implements a `reset_parameters()` method." │
│ 883 │ │ ) │
│ ❱ 884 │ │ raise e │
│ 885 │
│ 886 │
│ 887 def _get_modules_to_materialize( │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│ │
│ 874 │ │ │ │ has_module_states = len(list(module_state_iter)) > 0 │
│ 875 │ │ │ │ if has_module_states: │
│ 876 │ │ │ │ │ module.to_empty(device=materialization_device, recurse=False) │
│ ❱ 877 │ │ │ │ │ module.reset_parameters() # type: ignore[operator] │
│ 878 │ except BaseException as e: │
│ 879 │ │ warnings.warn( │
│ 880 │ │ │ "Unable to call `reset_parameters()` for module on meta " │
│ │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/nn/modules/modu │
│ │
│ 1685 │ │ │ modules = self.__dict__['_modules'] │
│ 1686 │ │ │ if name in modules: │
│ 1687 │ │ │ │ return modules[name] │
│ ❱ 1688 │ │ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") │
│ 1689 │ │
│ 1690 │ def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: │
│ 1691 │ │ def remove_from(*dicts_or_sets): │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'
However, we circumvented this issue by commenting out the raise statement (within torch/distributed/fsdp/_init_utils.py) as follows:
except BaseException as e:
warnings.warn(
"Unable to call `reset_parameters()` for module on meta "
f"device with error {str(e)}. Please ensure that your module of"
f"type {type(module)} implements a `reset_parameters()` method."
)
#raise e
I have attached the entire modified file within changes.zip, just in case.
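On a related note, rather than patching torch, a cleaner workaround for this second error might be to give the model's custom LayerNorm a reset_parameters() method, since FSDP calls reset_parameters() on every module it materializes off the meta device. The sketch below is only illustrative: the class is a simplified stand-in for whatever LayerNorm variant MatFormer-OLMo actually defines (its real __init__/forward would of course be kept), and the reset_parameters method is the only part that matters here:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    # Simplified stand-in for the custom LayerNorm in the MatFormer-OLMo codebase.
    def __init__(self, size: int, elementwise_affine: bool = True, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(size)) if elementwise_affine else None
        self.bias = nn.Parameter(torch.zeros(size)) if elementwise_affine else None

    def reset_parameters(self):
        # FSDP calls this after module.to_empty(...) when materializing meta-device modules.
        if self.weight is not None:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, x.shape[-1:], self.weight, self.bias, self.eps)

With such a method in place, the except/warn block in _init_utils.py should never be reached for LayerNorm, so the raise would not need to be commented out.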