
fix: torch.checkpoint() incorrectly wraps single forward step in original codebase. #274

Open
COLAZERO2 wants to merge 1 commit into SafeAILab:main from COLAZERO2:fix-grad

Conversation

@COLAZERO2 (Contributor)

Bug Fixes:
Fixes a bug that caused the loss to remain high due to unstable gradients when training with gradient checkpointing enabled. After the fix, the acceptance rate increases as intended when using the gradient-checkpoint memory optimization.

Modification:
Refactors the draft model’s forward function by separating the target-model hidden-state retrieval from the draft model’s layer flow. Wraps the entire training-time prediction pass over the drafting length in a single checkpointed call, removing the per-step torch.checkpoint() loop that previously produced a complicated computation graph and incorrect gradient flow (see the sketch below).
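For readers who want the shape of the change: below is a minimal sketch, not the actual code from this PR. The names `draft_step`, `rollout_per_step_checkpoint`, `rollout_whole_checkpoint`, `hidden_states`, and `draft_len` are hypothetical; the point is moving `torch.utils.checkpoint.checkpoint()` from wrapping each single forward step inside the drafting loop to wrapping the whole multi-step rollout once.

```python
import torch
from torch.utils.checkpoint import checkpoint


def draft_step(model, hidden_states, input_ids):
    # One forward step of the draft model (hypothetical signature).
    return model(hidden_states, input_ids)


# Before: checkpoint() wraps each single forward step inside the drafting
# loop, creating one recomputation segment per drafted token and a
# fragmented computation graph.
def rollout_per_step_checkpoint(model, hidden_states, input_ids, draft_len):
    outputs = []
    h = hidden_states
    for _ in range(draft_len):
        h = checkpoint(draft_step, model, h, input_ids, use_reentrant=False)
        outputs.append(h)
    return torch.stack(outputs)


# After: a single checkpoint() call wraps the entire training-time rollout
# over the drafting length, so gradients flow through one coherent segment.
def rollout_whole_checkpoint(model, hidden_states, input_ids, draft_len):
    def full_rollout(hidden_states, input_ids):
        outputs = []
        h = hidden_states
        for _ in range(draft_len):
            h = draft_step(model, h, input_ids)
            outputs.append(h)
        return torch.stack(outputs)

    return checkpoint(full_rollout, hidden_states, input_ids, use_reentrant=False)
```

Checkpointing the whole rollout trades one larger recomputation for a simpler graph, which is the behavior the description above aims for.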

@jasonyong

It works.

