
Issues with Fine-Tuning Checkpoint #2

@Advaid-Deepak

Description


We were trying to fine-tune a MatFormer checkpoint (MatFormer-OLMo-180M link).

We used the following command to launch the training script:

python train.py ../configs/pile-tiny.yaml \
    --matformer_factor=8 \
    --model.d_model=512 \
    --model.n_heads=16 \
    --model.n_layers=8 \
    --model.max_sequence_length=2048 \
    --device_train_microbatch_size=8 \
    --global_train_batch_size=128 \
    --max_duration=75000 \
    --optimizer.learning_rate=1.0e-3 \
    --console_log_interval=10 \
    --load_path="/raid/ganesh/namitha/Skill_localization_experiment/ckpt_paths/MatFormer-OLMo-180M" \
    --run_name="matformer-olmo-180M-finetune"

Here, the folder passed as load_path was downloaded from the MatFormer-OLMo-180M link in the README.
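With override lists this long, it is easy to repeat a flag or leave a stray character in a path value. The traceback below shows train.py passing every argument through clean_opt before TrainConfig.load. The sketch that follows assumes clean_opt simply strips leading dashes and splits on the first '='. That is a hypothetical approximation, not the repository's code. It uses this to flag repeated keys and path values that begin with ':'. The shell removes the double quotes, so the program sees the bare value.

```python
from collections import Counter


def clean_opt(arg: str) -> str:
    """Hypothetical approximation of train.py's clean_opt: '--a.b=c' -> 'a.b=c'."""
    if "=" not in arg:
        arg = f"{arg}=True"  # bare flags are treated as booleans in this sketch
    name, val = arg.split("=", 1)
    name = name.strip("-").replace("-", "_")
    return f"{name}={val}"


def check_overrides(args: list) -> list:
    """Return readable warnings for repeated keys and suspicious path values."""
    cleaned = [clean_opt(a).split("=", 1) for a in args]
    problems = []
    counts = Counter(key for key, _ in cleaned)
    for key, count in counts.items():
        if count > 1:
            problems.append(f"duplicate override: {key} (x{count})")
    for key, val in cleaned:
        if key.endswith("path") and val.startswith(":"):
            problems.append(f"suspicious leading ':' in {key}: {val}")
    return problems


if __name__ == "__main__":
    # argv as the shell would deliver it (quotes already removed)
    example = [
        "--matformer_factor=8",
        "--matformer_factor=8",
        "--load_path=:/raid/ganesh/namitha/Skill_localization_experiment/ckpt_paths/MatFormer-OLMo-180M",
    ]
    for problem in check_overrides(example):
        print(problem)
```

Running a check like this before launching costs nothing and catches typos that would otherwise only surface after model construction and FSDP wrapping.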

However, running this gives the following error:

[2024-04-18 09:09:04] CRITICAL [root, rank=0] Uncaught ValueError: Must flatten tensors on the same device but got both cuda:0 and meta
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    587 │   │   │   _sync_module_params_and_buffers(                                              │
│    588 │   │   │   │   fully_sharded_module, managed_params, state._inter_node_pg                │
│    589 │   │   │   )                                                                             │
│ ❱  590 │   _init_param_handle_from_params(state, managed_params, fully_sharded_module)           │
│    591 │   return state                                                                          │
│    592                                                                                           │
│    593                                                                                           │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    599 ):                                                                                        │
│    600 │   if len(params) == 0:                                                                  │
│    601 │   │   return                                                                            │
│ ❱  602 │   handle = FlatParamHandle(                                                             │
│    603 │   │   params,                                                                           │
│    604 │   │   fully_sharded_module,                                                             │
│    605 │   │   state.compute_device,                                                             │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    570 │   │   │   else 0                                                                        │
│    571 │   │   )                                                                                 │
│    572 │   │   self._fsdp_extension = fsdp_extension                                             │
│ ❱  573 │   │   self._init_flat_param_and_metadata(                                               │
│    574 │   │   │   params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: i │
│    575 │   │   )                                                                                 │
│    576 │   │   self._use_unsharded_views(as_params=False)                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    620 │   │   │   dtype,                                                                        │
│    621 │   │   │   flat_param_requires_grad,                                                     │
│    622 │   │   │   device,                                                                       │
│ ❱  623 │   │   ) = self._validate_tensors_to_flatten(params)                                     │
│    624 │   │   params_set = set(params)                                                          │
│    625 │   │   # For alignment padding, only `numels` gets strictly non-`None`                   │
│    626 │   │   # elements, and all other lists get `None` elements for padding.                  │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    773 │   │   │   │   │   "`use_orig_params=False`"                                             │
│    774 │   │   │   │   )                                                                         │
│    775 │   │   │   if device is not None and tensor.device != device:                            │
│ ❱  776 │   │   │   │   raise ValueError(                                                         │
│    777 │   │   │   │   │   "Must flatten tensors on the same device but got both "               │
│    778 │   │   │   │   │   f"{device} and {tensor.device}"                                       │
│    779 │   │   │   │   )                                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Must flatten tensors on the same device but got both cuda:0 and meta

We have not been able to resolve this issue.
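The ValueError comes from FSDP's check, in the last frame above, that every tensor flattened into one FSDP unit lives on the same device. Here, some parameters are on cuda:0 while others are still on the meta device. One way to narrow this down is to list which parameter names sit on which device before wrapping. The helper below is a hypothetical sketch that works on a plain name-to-device mapping, so it runs without a GPU. In practice, the mapping could be built with {n: str(p.device) for n, p in olmo_model.named_parameters()}. The parameter names in the example are made up.

```python
from collections import defaultdict
from typing import Dict, List


def group_by_device(param_devices: Dict[str, str]) -> Dict[str, List[str]]:
    """Map each device string (e.g. 'cuda:0', 'meta') to the parameter names on it."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for name, device in param_devices.items():
        groups[device].append(name)
    return dict(groups)


def meta_params(param_devices: Dict[str, str]) -> List[str]:
    """Names of parameters that were never materialized (still on the meta device)."""
    return sorted(name for name, dev in param_devices.items() if dev == "meta")


if __name__ == "__main__":
    # Hypothetical names/devices standing in for a real named_parameters() scan.
    example = {
        "transformer.wte.weight": "cuda:0",
        "transformer.blocks.0.attn_norm.weight": "cuda:0",
        "transformer.blocks.0.ff_norm.weight": "meta",
    }
    print(group_by_device(example))
    print(meta_params(example))
```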

We tried adding the following line to torch/distributed/fsdp/_init_utils.py:

tensor.to("cuda:0")  

However, this change produces another error:

│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    753 │   │   device: Optional[torch.device] = None                                             │
│    754 │   │   # For `use_orig_params=True`, permit non-uniform `requires_grad`                  │
│    755 │   │   for tensor in tensors:                                                            │
│ ❱  756 │   │   │   tensor.to("cuda:0")                                                           │
│    757 │   │   │   if isinstance(tensor, FlatParameter):                                         │
│    758 │   │   │   │   raise ValueError("Cannot flatten a `FlatParameter`")                      │
│    759 │   │   │   if dtype is None and not tensor.is_floating_point():                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
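For context, this NotImplementedError is how PyTorch treats meta tensors. They carry shape and dtype but no storage, so there is nothing to copy to cuda:0. Separately, Tensor.to() is out-of-place and returns a new tensor. The added line would therefore not have moved the parameter even if the copy had succeeded. The CPU-only sketch below assumes PyTorch 2.x, and the tensor sizes are arbitrary. It shows both points, plus nn.Module.to_empty(). That is the materialization step FSDP itself uses in the next traceback: module.to_empty(..., recurse=False) followed by module.reset_parameters().

```python
import torch
import torch.nn as nn

# Tensor.to() is out-of-place: the result must be assigned; the input is unchanged.
src = torch.zeros(2)
converted = src.to(torch.float64)

# A meta tensor carries shape/dtype metadata only, so copying it anywhere fails.
meta_tensor = torch.empty(4, 4, device="meta")
try:
    meta_tensor.to("cpu")
    copy_error = None
except NotImplementedError as err:
    copy_error = str(err)

# Meta-device modules are materialized with to_empty(), which allocates
# *uninitialized* storage on the target device. Values must then be set
# explicitly, e.g. via reset_parameters() or by loading a checkpoint.
layer = nn.Linear(4, 4, device="meta")
layer.to_empty(device="cpu")
layer.reset_parameters()  # nn.Linear defines this, so FSDP-style materialization works for it

print(src.dtype, converted.dtype)
print(copy_error)
print(layer.weight.device)
```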

We also made other changes to pile-tiny.yaml, scripts/train.py, and scripts/util.py to make them compatible with training.
I am attaching a zip of those files here:
changes.zip

Apart from this, we also ran into another issue:

[2024-04-18 09:30:56] CRITICAL [root, rank=0] Uncaught AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    549 │   │   │   fully_sharded_module, param_init_fn, state._ignored_modules                   │
│    550 │   │   )                                                                                 │
│    551 │   elif is_meta_module:                                                                  │
│ ❱  552 │   │   _materialize_meta_module(                                                         │
│    553 │   │   │   fully_sharded_module, device_id, state._ignored_modules                       │
│    554 │   │   )                                                                                 │
│    555 │   elif is_torchdistX_deferred_init:                                                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    881 │   │   │   f"device with error {str(e)}. Please ensure that your module of"              │
│    882 │   │   │   f"type {type(module)} implements a `reset_parameters()` method."              │
│    883 │   │   )                                                                                 │
│ ❱  884 │   │   raise e                                                                           │
│    885                                                                                           │
│    886                                                                                           │
│    887 def _get_modules_to_materialize(                                                          │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    874 │   │   │   │   has_module_states = len(list(module_state_iter)) > 0                      │
│    875 │   │   │   │   if has_module_states:                                                     │
│    876 │   │   │   │   │   module.to_empty(device=materialization_device, recurse=False)         │
│ ❱  877 │   │   │   │   │   module.reset_parameters()  # type: ignore[operator]                   │
│    878 │   except BaseException as e:                                                            │
│    879 │   │   warnings.warn(                                                                    │
│    880 │   │   │   "Unable to call `reset_parameters()` for module on meta "                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/nn/modules/modu │
│                                                                                                  │
│   1685 │   │   │   modules = self.__dict__['_modules']                                           │
│   1686 │   │   │   if name in modules:                                                           │
│   1687 │   │   │   │   return modules[name]                                                      │
│ ❱ 1688 │   │   raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") │
│   1689 │                                                                                         │
│   1690 │   def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:             │
│   1691 │   │   def remove_from(*dicts_or_sets):                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'

We worked around this issue by commenting out the raise e line in torch/distributed/fsdp/_init_utils.py, as follows:

except BaseException as e:
    warnings.warn(
        "Unable to call `reset_parameters()` for module on meta "
        f"device with error {str(e)}. Please ensure that your module of"
        f"type {type(module)} implements a `reset_parameters()` method."
    )
    # raise e

The complete modified file is also included in changes.zip, just in case.
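The warning text above asks for a reset_parameters() method. FSDP calls it right after to_empty() when materializing a meta-device module, as the module.to_empty(...) / module.reset_parameters() frame in the second traceback shows. Below is a minimal sketch of a LayerNorm-style module that provides one. The class name, the eps default, and the weight/bias layout are assumptions for illustration and do not reproduce MatFormer-OLMo's actual LayerNorm. Commenting out raise e may also have side effects. In the traceback, the except clause appears to sit outside the per-module materialization loop. Once one module fails, the remaining modules are therefore likely never moved off the meta device, which would be consistent with the "cuda:0 and meta" error reported first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormWithReset(nn.Module):
    """Hypothetical LayerNorm-style module that FSDP can materialize from the meta device."""

    def __init__(self, d_model: int, eps: float = 1e-5, device=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(d_model, device=device))
        self.bias = nn.Parameter(torch.empty(d_model, device=device))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Standard LayerNorm init: scale of 1, shift of 0.
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)


if __name__ == "__main__":
    # Mimic FSDP's meta-device materialization on CPU: to_empty() then reset_parameters().
    norm = LayerNormWithReset(8, device="meta")
    norm.to_empty(device="cpu")
    norm.reset_parameters()
    print(norm(torch.randn(2, 8)).shape)
```

When fine-tuning, the values written by reset_parameters() only matter until the checkpoint weights are loaded over them.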
