-
Notifications
You must be signed in to change notification settings - Fork 88
Integration with DCP #978
base: unflatten
Are you sure you want to change the base?
Integration with DCP #978
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates | ||
| import torch | ||
| from pippy import annotate_split_points, Pipe, SplitPoint | ||
| import torch.distributed.checkpoint as dcp | ||
| import tempfile | ||
|
|
||
|
|
||
| d_hid = 16 | ||
|
|
@@ -66,6 +68,49 @@ def get_layers(module): | |
| return layers | ||
|
|
||
|
|
||
| def pipe_to_sd(pipe): | ||
| sd = {} | ||
| for stage_idx in range(pipe.num_stages): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something a little fishy about this proposal (equally so for both option 1 and 2) is that it's not likely you'd want to iterate all the stages in the pipe and load/save them. Example 1: simple pipeline with 4 gpus |
||
| stage_mod = pipe.get_stage_module(stage_idx) | ||
| sd[f"stage_{stage_idx}"] = stage_mod | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not really clear to me why we need to add a prefix at all. There should be no duplication of fqns between submods/stages. what are we doing about the 'submod_0' part in the fqn? when we do If the former, can't we just save/load the keys as usual? If the latter, we can still save/load without a prefix of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Former. @wconstab |
||
| return sd | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| #Simulate saving the pipe | ||
| # Option 1: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Option 1 would be more likely used than Option 2 in realistic setting. Could you please uncomment this block of code? |
||
| # for stage_idx in range(pipe.num_stages): | ||
| # print(f"Saving pipeline stage {stage_idx}") | ||
| # stage_mod = pipe.get_stage_module(stage_idx) | ||
| # dcp.save( | ||
| # {f"stage_{stage_idx}": stage_mod}, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious, is the dict required by API of DCP? Can a user directly save
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does this matter? i think the DCP api had reasons for interfacing with dict instead of model, adding a new variant that takes model and gets its dict should be possible, but i think it's clearer this way that the only part of the model that gets saved is the dict
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me. |
||
| # checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
| # ) | ||
| # Option 2: | ||
| sd = pipe_to_sd(pipe) | ||
| dcp.save(state_dict, checkpoint_id=tmpdir) | ||
|
|
||
|
|
||
| #Simulate loading the pipe | ||
| # Option 1: | ||
| # for stage_idx in range(pipe.num_stages): | ||
| # print(f"Loading pipeline stage {stage_idx}") | ||
| # stage_mod = pipe.get_stage_module(stage_idx) | ||
| # dcp.load( | ||
| # {f"stage_{stage_idx}": stage_mod}, | ||
| # checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
| # ) | ||
|
|
||
| #Option 2: | ||
| new_pipe = Pipe.from_tracing( | ||
| transformer, | ||
| 1, | ||
| (x,), | ||
| ) | ||
| sd = pipe_to_sd(new_pipe) | ||
| dcp.load(sd, checkpoint_id=tmpdir) | ||
|
|
||
| pipe = new_pipe | ||
|
|
||
| # Collect all layers in pipe | ||
| layers = [] | ||
| for stage_idx in range(pipe.num_stages): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wz337 , might be interesting in dist state dict