From 44ccb8ee956c1f5e1839fed5d28459aa170c1314 Mon Sep 17 00:00:00 2001 From: Shusheng Yang Date: Fri, 11 Jul 2025 01:13:54 +0000 Subject: [PATCH] fallback __getattr__ to _orig_mod for ShardedModule --- torchprime/sharding/shard_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchprime/sharding/shard_model.py b/torchprime/sharding/shard_model.py index bf2030e6..737c64a1 100644 --- a/torchprime/sharding/shard_model.py +++ b/torchprime/sharding/shard_model.py @@ -315,3 +315,13 @@ def __init__(self, mod, mark_sharding, spec): def forward(self, *args, **kwargs): return self.mark_sharding(self._orig_mod(*args, **kwargs), self.spec) + + @property + def module(self): + return self._orig_mod + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name)