# Force weights_only=False in torch.load to fix PyTorch 2.6+ checkpoint loading #37
```diff
@@ -43,12 +43,14 @@
         OrderedDict,
     ])

-    # Monkey-patch torch.load to use weights_only=False by default
+    # Monkey-patch torch.load to force weights_only=False
     # This is needed for PyTorch 2.6+ compatibility with PyTorch Lightning
+    # PyTorch Lightning explicitly passes weights_only=True, so we need to override it
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):
```
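Why the earlier conditional version was not enough: PyTorch Lightning passes `weights_only=True` explicitly, so a patch that only supplies a default for a *missing* keyword never takes effect for those calls. A minimal standalone sketch contrasting the two patch variants (the call sites and checkpoint path are illustrative, not from this PR):

```python
import torch

_original_torch_load = torch.load

def _conditional_patch(*args, **kwargs):
    # Old behavior: only supplies a default when the caller left it unset.
    if 'weights_only' not in kwargs:
        kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)

def _forcing_patch(*args, **kwargs):
    # New behavior: overrides the caller's choice unconditionally.
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)

# A caller that pins weights_only=True (as PyTorch Lightning does for
# checkpoint loads) slips past the conditional patch untouched:
#   _conditional_patch("model.ckpt", weights_only=True)  # still safe mode, may raise
#   _forcing_patch("model.ckpt", weights_only=True)      # forced to weights_only=False
```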
```diff
@@ -11,8 +11,9 @@
     import torch
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):
```
**Copilot AI** suggested a change to the final added comment:

```diff
-# This is safe because we're loading locally-created NeuralProphet checkpoints
+# WARNING: This disables torch.load's weights_only safeguard globally; use only in trusted test environments with known-safe checkpoints.
```

**Copilot AI** commented on Dec 30, 2025:

Unconditionally forcing `weights_only=False` overrides any security-conscious callers (like PyTorch Lightning) that explicitly set `weights_only=True`. This approach disables PyTorch's security safeguards for all checkpoint loads in the application, not just NeuralProphet checkpoints.

Consider a more targeted approach that preserves the caller's intent when explicitly specified. For example, you could check whether the checkpoint path indicates a NeuralProphet checkpoint before overriding, or only apply the override when loading from trusted local paths. A safer approach might be to check whether an `UnpicklingError` occurs with `weights_only=True` and only then retry with `weights_only=False` for known-safe checkpoint types.
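The retry idea the reviewer sketches could look roughly like the following. This is a hedged sketch, not code from this PR; the `.ckpt`-suffix test is a stand-in for whatever "known safe checkpoint type" check the project would actually use:

```python
import pickle
import torch

_original_torch_load = torch.load

def _cautious_torch_load(*args, **kwargs):
    try:
        # First honor whatever the caller asked for, including PyTorch
        # Lightning's explicit weights_only=True.
        return _original_torch_load(*args, **kwargs)
    except pickle.UnpicklingError:
        # The safe path rejected a non-allowlisted global. Retry with
        # weights_only=False only for checkpoints we consider trusted.
        f = args[0] if args else kwargs.get("f")
        if isinstance(f, str) and f.endswith(".ckpt"):  # illustrative trust check
            kwargs["weights_only"] = False
            return _original_torch_load(*args, **kwargs)
        raise

torch.load = _cautious_torch_load
```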
**Copilot AI** commented on Dec 30, 2025:

The `except` clause does nothing but `pass`, and there is no explanatory comment.

```diff
 except (ImportError, AttributeError):
+    # If PyTorch is not available or torch.load is missing, skip patching and
+    # run the tests without this compatibility adjustment.
```
**Copilot AI** commented:

The comment states "PyTorch Lightning explicitly passes weights_only=True, so we need to override it", which suggests the approach is designed to bypass PyTorch Lightning's security setting. This is a concerning design decision: PyTorch Lightning likely passes `weights_only=True` intentionally, for security reasons.

Instead of forcefully overriding PyTorch Lightning's security choice, consider using the `add_safe_globals` approach (which is already implemented in this file at lines 34-44) more comprehensively, or work with PyTorch Lightning's configuration options to handle checkpoint loading differently.
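For reference, the allowlist route builds on `torch.serialization.add_safe_globals`, which the diff already uses for `OrderedDict`. A hedged sketch of extending it rather than disabling the safeguard; the additional classes a NeuralProphet checkpoint actually needs would have to be collected from the real `UnpicklingError` messages:

```python
from collections import OrderedDict

import torch

# Extend the set of globals that weights_only=True unpickling will accept,
# rather than disabling the safeguard wholesale. add_safe_globals is
# available in recent PyTorch releases.
torch.serialization.add_safe_globals([
    OrderedDict,
    # Assumption: any NeuralProphet/Lightning classes the checkpoint embeds
    # would be added here, e.g. the classes named in the UnpicklingError.
])

# Safe loading can then stay enabled:
# checkpoint = torch.load("model.ckpt", weights_only=True)
```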