diff --git a/predai/rootfs/predai.py b/predai/rootfs/predai.py
index 00916e9..44c037c 100644
--- a/predai/rootfs/predai.py
+++ b/predai/rootfs/predai.py
@@ -43,12 +43,14 @@
         OrderedDict,
     ])
 
-    # Monkey-patch torch.load to use weights_only=False by default
+    # Monkey-patch torch.load to force weights_only=False
     # This is needed for PyTorch 2.6+ compatibility with PyTorch Lightning
+    # PyTorch Lightning explicitly passes weights_only=True, so we need to override it
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):
diff --git a/test_predai.py b/test_predai.py
index ec8a6e4..982d98e 100644
--- a/test_predai.py
+++ b/test_predai.py
@@ -11,8 +11,9 @@
     import torch
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):
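For context, a minimal standalone sketch of the patch above and of why the unconditional override is needed, assuming PyTorch >= 2.6 (where torch.load defaults to weights_only=True); the checkpoint filename is hypothetical:

    import torch

    _original_torch_load = torch.load

    def _patched_torch_load(*args, **kwargs):
        # The previous version of the patch only set a default:
        #     if 'weights_only' not in kwargs:
        #         kwargs['weights_only'] = False
        # A default is never applied when a caller (e.g. PyTorch Lightning)
        # passes weights_only=True explicitly, so the value must be forced.
        kwargs['weights_only'] = False
        return _original_torch_load(*args, **kwargs)

    torch.load = _patched_torch_load

    # Even an explicit weights_only=True from a caller is now overridden,
    # so pickled non-tensor objects in a locally-created checkpoint can
    # deserialize. ("local_neuralprophet.ckpt" is a placeholder path.)
    checkpoint = torch.load("local_neuralprophet.ckpt", weights_only=True)

Note that weights_only is a keyword-only parameter of torch.load, so intercepting kwargs is sufficient; no positional argument can bypass the override.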