Skip to content

Conversation

@chaoming0625
Copy link
Member

@chaoming0625 chaoming0625 commented Jan 1, 2026

Summary by Sourcery

Introduce a comprehensive State hook system, enhance nn.Module and Param abstractions with PyTorch-like APIs and parameter caching, upgrade delay and transform/mapping utilities, and reorganize public exports across core, nn, random, graph, transform, and util modules.

New Features:

  • Add a flexible State hook system with per-state and global hooks, priorities, cancellation, error handling modes, and context-manager utilities.
  • Introduce nn.Param and Const parameter modules with support for value transforms, regularization, and automatic cache management of transformed values.
  • Add rich interpolation support for Delay/StateWithDelay including nearest, linear, cubic, Hermite, and polynomial methods via an InterpolationRegistry.
  • Provide higher-level mapping utilities (StateAxes, model_vmap, model_pmap) and hierarchical data containers via nn.HiData for structured state/parameter management.

Enhancements:

  • Refine nn.Module with clearer documentation, PyTorch-style traversal/introspection APIs (children, modules, parameters), regularization aggregation, and parameter precompute control.
  • Extend Sequential with extend and insert for dynamic model composition and automatic size propagation.
  • Unify Delay buffering semantics around a ring-buffer with write pointer state, remove concat mode, and improve batched delay handling and thread safety.
  • Restructure nn, transform, random, graph, and util package exports into explicit, curated all lists to provide a more stable and discoverable public API surface.
  • Tighten State semantics with tag sets, copy/replace helpers, JAX trace checking, and integration points for the new hook system and vmap2 mapping.
  • Improve vmap2/map implementations to better track state axes, random keys, and hierarchy-aware state batching, and add eval_shape support for stateful functions.

Tests:

  • Add extensive unit tests for State hooks, thread safety, and the hook manager, covering priorities, error modes, cancellation, global vs instance hooks, and integration scenarios.
  • Add broad nn.Module tests for parameter and module traversal, regularization loss aggregation, PyTorch-style iteration patterns, and Sequential editing.
  • Introduce a dedicated Param caching test suite validating cache validity, invalidation on writes, transform interaction, thread safety, and performance benefits.
  • Add tests for new Delay interpolation methods, ring-buffer behavior, and upgraded delay API, plus mapping and graph context tests for vmap2 and model_vmap/model_pmap.
  • Add tests for new nn utilities including HiData hierarchical containers, Transform variants, and Regularization classes to ensure round-trip behavior and loss correctness.

Copy link
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @chaoming0625, your pull request is larger than the review limit of 150000 diff characters

@chaoming0625
Copy link
Member Author

@sourcery-ai title

@sourcery-ai sourcery-ai bot changed the title Update Add State hook system and refactor nn modules and transforms Jan 1, 2026
@chaoming0625
Copy link
Member Author

@sourcery-ai summary

@chaoming0625 chaoming0625 merged commit 091d342 into main Jan 1, 2026
5 checks passed
@chaoming0625 chaoming0625 deleted the update branch January 1, 2026 14:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants