Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions predai/rootfs/predai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@
OrderedDict,
])

# Monkey-patch torch.load to use weights_only=False by default
# Monkey-patch torch.load to force weights_only=False
# This is needed for PyTorch 2.6+ compatibility with PyTorch Lightning
# PyTorch Lightning explicitly passes weights_only=True, so we need to override it
Comment on lines +46 to +48
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "PyTorch Lightning explicitly passes weights_only=True, so we need to override it", which suggests the approach is designed to bypass PyTorch Lightning's security setting. This is a concerning design decision. PyTorch Lightning likely passes weights_only=True intentionally for security reasons.

Instead of forcefully overriding PyTorch Lightning's security choice, consider using the add_safe_globals approach (which is already implemented in this file at lines 34-44) more comprehensively, or work with PyTorch Lightning's configuration options to handle checkpoint loading differently.

Copilot uses AI. Check for mistakes.
_original_torch_load = torch.load
def _patched_torch_load(*args, **kwargs):
if 'weights_only' not in kwargs:
kwargs['weights_only'] = False
# Force weights_only=False for all checkpoint loads
# This is safe because we're loading locally-created NeuralProphet checkpoints
Comment on lines +51 to +52
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The claim "This is safe because we're loading locally-created NeuralProphet checkpoints" is not enforced by the code. The monkey patch applies to ALL calls to torch.load() throughout the entire application, regardless of the source of the checkpoint file or what library is calling it.

If the application ever loads checkpoints from external sources, or if any dependency uses torch.load(), this blanket override removes protection against malicious pickle files. The safety claim in the comment doesn't match the actual implementation.

Copilot uses AI. Check for mistakes.
kwargs['weights_only'] = False
Comment on lines 50 to +53
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditionally forcing weights_only=False overrides any security-conscious callers (like PyTorch Lightning) that explicitly set weights_only=True. This approach disables PyTorch's security safeguards for all checkpoint loads in the application, not just NeuralProphet checkpoints.

Consider a more targeted approach that preserves the caller's intent when explicitly specified. For example, you could check if the checkpoint path indicates it's a NeuralProphet checkpoint before overriding, or only apply the override when loading from trusted local paths. A safer approach might be to check if an UnpicklingError occurs with weights_only=True and only then retry with weights_only=False for known safe checkpoint types.

Copilot uses AI. Check for mistakes.
return _original_torch_load(*args, **kwargs)
torch.load = _patched_torch_load
except (ImportError, AttributeError):
Expand Down
5 changes: 3 additions & 2 deletions test_predai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import torch
_original_torch_load = torch.load
def _patched_torch_load(*args, **kwargs):
if 'weights_only' not in kwargs:
kwargs['weights_only'] = False
# Force weights_only=False for all checkpoint loads
# This is safe because we're loading locally-created NeuralProphet checkpoints
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The claim "This is safe because we're loading locally-created NeuralProphet checkpoints" is not enforced by the code. The monkey patch applies to ALL calls to torch.load() throughout the entire application, regardless of the source of the checkpoint file or what library is calling it.

If the application ever loads checkpoints from external sources, or if any dependency uses torch.load(), this blanket override removes protection against malicious pickle files. The safety claim in the comment doesn't match the actual implementation.

Suggested change
# This is safe because we're loading locally-created NeuralProphet checkpoints
# WARNING: This disables torch.load's weights_only safeguard globally; use only in trusted test environments with known-safe checkpoints.

Copilot uses AI. Check for mistakes.
kwargs['weights_only'] = False
Comment on lines 13 to +16
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditionally forcing weights_only=False overrides any security-conscious callers (like PyTorch Lightning) that explicitly set weights_only=True. This approach disables PyTorch's security safeguards for all checkpoint loads in the application, not just NeuralProphet checkpoints.

Consider a more targeted approach that preserves the caller's intent when explicitly specified. For example, you could check if the checkpoint path indicates it's a NeuralProphet checkpoint before overriding, or only apply the override when loading from trusted local paths. A safer approach might be to check if an UnpicklingError occurs with weights_only=True and only then retry with weights_only=False for known safe checkpoint types.

Copilot uses AI. Check for mistakes.
return _original_torch_load(*args, **kwargs)
torch.load = _patched_torch_load
except (ImportError, AttributeError):
Copy link

Copilot AI Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'except' clause does nothing but pass and there is no explanatory comment.

Suggested change
except (ImportError, AttributeError):
except (ImportError, AttributeError):
# If PyTorch is not available or torch.load is missing, skip patching and
# run the tests without this compatibility adjustment.

Copilot uses AI. Check for mistakes.
Expand Down
Loading