From b9c867abfe4742b276d4ebd0177df5544e9023f9 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 30 Dec 2025 16:11:38 +0000
Subject: [PATCH 1/2] Initial plan

From a24760ba0cda2d41b6e95c02ef7455fa4bcca57d Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 30 Dec 2025 16:19:02 +0000
Subject: [PATCH 2/2] Fix PyTorch 2.6+ weights_only checkpoint loading by forcing weights_only=False

Co-authored-by: springfall2008 <48591903+springfall2008@users.noreply.github.com>
---
 predai/rootfs/predai.py | 8 +++++---
 test_predai.py          | 5 +++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/predai/rootfs/predai.py b/predai/rootfs/predai.py
index 00916e9..44c037c 100644
--- a/predai/rootfs/predai.py
+++ b/predai/rootfs/predai.py
@@ -43,12 +43,14 @@
         OrderedDict,
     ])
 
-    # Monkey-patch torch.load to use weights_only=False by default
+    # Monkey-patch torch.load to force weights_only=False
     # This is needed for PyTorch 2.6+ compatibility with PyTorch Lightning
+    # PyTorch Lightning explicitly passes weights_only=True, so we need to override it
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):
diff --git a/test_predai.py b/test_predai.py
index ec8a6e4..982d98e 100644
--- a/test_predai.py
+++ b/test_predai.py
@@ -11,8 +11,9 @@
     import torch
     _original_torch_load = torch.load
     def _patched_torch_load(*args, **kwargs):
-        if 'weights_only' not in kwargs:
-            kwargs['weights_only'] = False
+        # Force weights_only=False for all checkpoint loads
+        # This is safe because we're loading locally-created NeuralProphet checkpoints
+        kwargs['weights_only'] = False
         return _original_torch_load(*args, **kwargs)
     torch.load = _patched_torch_load
 except (ImportError, AttributeError):