Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions parametricmatrixmodels/nonsequentialmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,16 @@ def _get_callable(
self._get_module_input_dependencies()
)

# Check if we can use jax.lax.scan for a chain of identical modules.
# This traces the module body once instead of N times, dramatically
# reducing JIT compile time.
use_scan = self._can_use_scan(module_input_deps, out_input_deps)

if use_scan:
return self._get_scan_callable(
module_callables, module_input_deps, out_input_deps
)

@jaxtyped(typechecker=beartype)
def nonseq_callable(
params: ModelParams,
Expand Down Expand Up @@ -909,6 +919,210 @@ def nonseq_callable(

return nonseq_callable

def _can_use_scan(
    self,
    module_input_deps: List,
    out_input_deps: PyTree,
) -> bool:
    r"""
    Check whether the execution order forms a simple chain of identical
    modules that can be executed with ``jax.lax.scan``.

    All of the following must hold for the check to succeed:

    1. There are at least 2 modules, i.e. entries of
       ``self.execution_order`` other than ``"input"`` and ``"output"``.
    2. All modules are of exactly the same type.
    3. All modules have the same parameter tree structure, leaf shapes
       and leaf dtypes.
    4. All modules have the same state tree structure, leaf shapes and
       leaf dtypes.
    5. The connections form a simple linear chain: the first module
       depends on ``"input"``, every later module depends on the module
       before it, and the output depends on the last module.

    Leaf dtypes are compared in addition to shapes because the scan path
    stacks the leaves of all modules into a single array; stacking leaves
    of different dtypes would silently promote them and make the scan
    path differ numerically from the unrolled path.

    Parameters
    ----------
    module_input_deps
        Input dependency of each entry of ``self.execution_order``,
        iterated in lockstep with ``self.execution_order``.
    out_input_deps
        Input dependency of the model output.

    Returns
    -------
    bool
        ``True`` if every condition above holds, ``False`` otherwise.
        Any exception raised while querying a module's parameters or
        state is treated as "cannot use scan" and yields ``False``.
    """
    # NOTE(review): modules of the same type with identical parameter and
    # state layouts may still behave differently (e.g. different
    # hyperparameters), and nothing here checks that a module's output
    # has the same shape as its input, which ``jax.lax.scan`` requires of
    # its carry. Confirm whether additional checks are needed.

    module_paths = [
        p for p in self.execution_order if p not in ("input", "output")
    ]
    if len(module_paths) < 2:
        return False

    modules_list = [
        getitem_by_strpath(
            self.modules,
            mod_path,
            separator=self.separator,
            allow_early_return=False,
            return_remainder=False,
        )
        for mod_path in module_paths
    ]

    # Condition: all modules are of exactly the same type
    first_type = type(modules_list[0])
    if not all(type(m) is first_type for m in modules_list):
        return False

    def signature(tree):
        # Everything that must agree for two PyTrees to be stackable
        # leaf-by-leaf without changing dtype.
        leaves = jax.tree.leaves(tree)
        return (
            jax.tree.structure(tree),
            [leaf.shape for leaf in leaves],
            [getattr(leaf, "dtype", None) for leaf in leaves],
        )

    def all_match(get_tree) -> bool:
        reference = signature(get_tree(modules_list[0]))
        return all(
            signature(get_tree(m)) == reference for m in modules_list[1:]
        )

    # Conditions: same parameter layout, then same state layout.
    # Deliberately best-effort: if a module cannot report its parameters
    # or state (or a leaf has no ``shape``), fall back to the unrolled
    # path instead of raising.
    try:
        if not all_match(lambda m: m.get_params()):
            return False
        if not all_match(lambda m: m.get_state()):
            return False
    except Exception:
        return False

    # Condition: simple linear chain. The k-th module (skipping the
    # "input" and "output" entries) must depend on "input" for k == 0 and
    # on the (k-1)-th module otherwise.
    module_deps = [
        req_in
        for mod_path, req_in in zip(self.execution_order, module_input_deps)
        if mod_path not in ("input", "output")
    ]
    expected_deps = ["input", *module_paths[:-1]]
    for req_in, expected in zip(module_deps, expected_deps):
        if req_in != expected:
            return False

    # The output must depend on the last module in the chain
    if out_input_deps != module_paths[-1]:
        return False

    return True

def _get_scan_callable(
    self,
    module_callables: List,
    module_input_deps: List,
    out_input_deps: PyTree,
) -> ModelCallable:
    r"""
    Build a scan-based callable for a chain of identical modules.

    Uses ``jax.lax.scan`` so that the module body is traced once instead
    of once per module, which reduces JIT compile time. The parameters
    and states of all modules are stacked along a new leading axis and
    scanned over, with the data flowing through as the scan carry.

    Parameters
    ----------
    module_callables
        Callables for the entries of the execution order; the entry at
        index 1 is used for every step of the scan.
    module_input_deps
        Unused here; accepted for signature parity with the caller.
    out_input_deps
        Unused here; accepted for signature parity with the caller.

    Returns
    -------
    ModelCallable
        A callable ``(params, data, training, state, rng)`` returning the
        final data and the updated model state.
    """
    import jax.numpy as np
    from jax import lax

    module_paths = [
        p for p in self.execution_order if p not in ("input", "output")
    ]
    num_modules = len(module_paths)

    # Use the first module's callable for every step.
    # NOTE(review): this presumes all modules in the chain have
    # interchangeable callables -- confirm that the eligibility check
    # guarantees this.
    single_callable = module_callables[1]  # [0] is None for "input"

    @jaxtyped(typechecker=beartype)
    def scan_nonseq_callable(
        params: ModelParams,
        data: Data,
        training: bool,
        state: ModelState,
        rng: Any,
    ) -> Tuple[Data, ModelState]:

        def stack_per_module(tree):
            # Pull out each module's sub-tree and stack matching leaves
            # along a new leading axis of length ``num_modules``.
            per_module = [
                getitem_by_strpath(
                    tree,
                    mod_path,
                    separator=self.separator,
                    allow_early_return=False,
                    return_remainder=False,
                )
                for mod_path in module_paths
            ]
            return jax.tree.map(
                lambda *leaves: np.stack(leaves), *per_module
            )

        stacked_params = stack_per_module(params)
        stacked_states = stack_per_module(state)

        # Scan body: process the carried data through one module
        def scan_body(carry, step_input):
            step_params, step_state = step_input
            # NOTE(review): the same ``rng`` is passed to every step;
            # confirm this matches the non-scan execution path.
            out, new_step_state = single_callable(
                step_params,
                carry,
                training,
                step_state,
                rng,
            )
            return out, new_step_state

        # ``length`` is given explicitly so the scan also works when the
        # modules have neither parameter nor state leaves; in that case
        # there is nothing in ``xs`` to infer the iteration count from.
        final_data, new_stacked_states = lax.scan(
            scan_body,
            data,
            (stacked_params, stacked_states),
            length=num_modules,
        )

        # Reconstruct the state PyTree by unstacking along the leading
        # axis.
        # NOTE(review): ``new_state`` aliases ``state`` and the return
        # value of ``setitem_by_strpath`` is ignored, so this relies on
        # it updating its first argument in place -- confirm.
        new_state = state
        for idx, mod_path in enumerate(module_paths):
            mod_new_state = jax.tree.map(
                lambda s, idx=idx: s[idx], new_stacked_states
            )
            setitem_by_strpath(
                new_state,
                mod_path,
                mod_new_state,
                separator=self.separator,
            )

        return final_data, new_state

    return scan_nonseq_callable

def get_state(self) -> ModelState:
r"""
Get the state of all modules in the model as a PyTree.
Expand Down
Loading