Skip to content

Conversation

@chaoming0625
Copy link
Member

@chaoming0625 chaoming0625 commented Jan 2, 2026

Summary by Sourcery

Update vectorization tutorial and mapping API to reflect the vmap2 transform instead of vmap.

Enhancements:

  • Relax the __call__ method return type in the stateful mapping implementation to be unconstrained for greater flexibility.

Documentation:

  • Retarget the vectorization tutorial to vmap2, updating narrative text, headings, and code examples accordingly.

@sourcery-ai
Copy link
Contributor

sourcery-ai bot commented Jan 2, 2026

Reviewer's Guide

Updates the vectorization tutorial and references to consistently document vmap2 instead of vmap, adjusts examples/imports accordingly, normalizes an embedded Jaxpr text block’s escape sequences, and relaxes the return annotation of the StatefulMapping.__call__ implementation.

Class diagram for StatefulMapping and vmap2 after call signature change

classDiagram
    class StatefulMapping {
        +__call__(*args, **kwargs) Any
    }

    class vmap2 {
    }

    vmap2 --> StatefulMapping : wraps
Loading

File-Level Changes

Change Details Files
Rename tutorial and examples from vmap to vmap2 and update code usage accordingly.
  • Update introductory bullet list and section headings to refer to vmap2 instead of vmap.
  • Change import to use vmap2 directly instead of aliasing it as vmap.
  • Replace all example usages of vmap(...) with vmap2(...) across the notebook, including nested and composed transform examples.
  • Fix a typo in the example where the compiled mapping is invoked, aligning the call with the vmap2 naming.
docs/tutorials/transforms/03_vectorization.ipynb
Normalize formatting of embedded Jaxpr output snippet.
  • Adjust ANSI escape sequences in the captured Jaxpr text to use consistent escape formatting.
  • Ensure the compiled Jaxpr printout matches current tooling output while remaining readable in the notebook.
docs/tutorials/transforms/03_vectorization.ipynb
Generalize StatefulMapping.call return type.
  • Remove the concrete Tuple[Any, Tuple[State, ...]] return type annotation from call in the stateful mapping implementation so it can return arbitrary structures.
  • Keep the implementation logic intact, only relaxing the type signature.
brainstate/transform/_mapping2.py
Add or update API documentation stubs for several brainstate modules.
  • Touch or regenerate RST API index files for the brainstate, environ, nn, and transform modules, likely to reflect the vmap2-facing API surface or fix doc build issues.
docs/apis/brainstate.rst
docs/apis/environ.rst
docs/apis/nn.rst
docs/apis/transform.rst

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@chaoming0625
Copy link
Member Author

@sourcery-ai title

Copy link
Contributor

@sourcery-ai sourcery-ai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 4 issues, and left some high level feedback:

  • In the tutorial notebook you renamed the imported symbol to vmap2 but left some call sites inconsistent (e.g., simple_vmap is defined but later _ = simple_vmap2(test_input) is called), which will raise a NameError when executing the notebook; please align the variable names with the imported function name throughout.
  • You removed the return type annotation from StatefulMapping.__call__ in _mapping2.py; if the method still always returns a (output, states) tuple, keeping or updating a precise return type would help static checking and IDE support.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In the tutorial notebook you renamed the imported symbol to `vmap2` but left some call sites inconsistent (e.g., `simple_vmap` is defined but later `_ = simple_vmap2(test_input)` is called), which will raise a NameError when executing the notebook; please align the variable names with the imported function name throughout.
- You removed the return type annotation from `StatefulMapping.__call__` in `_mapping2.py`; if the method still always returns a `(output, states)` tuple, keeping or updating a precise return type would help static checking and IDE support.

## Individual Comments

### Comment 1
<location> `docs/tutorials/transforms/03_vectorization.ipynb:1105` </location>
<code_context>
     "# Call once to trigger compilation\n",
     "test_input = jnp.array([1.0, 2.0])\n",
-    "_ = simple_vmap(test_input)\n",
+    "_ = simple_vmap2(test_input)\n",
     "\n",
     "# Access the compiled Jaxpr\n",
</code_context>

<issue_to_address>
**issue (bug_risk):** Likely typo: calling `simple_vmap2` instead of the defined `simple_vmap`.

`simple_vmap` is defined as `simple_vmap = vmap2(...)` above, but `simple_vmap2` is not defined, so this call will raise a `NameError`. This line should call `simple_vmap(test_input)` instead.
</issue_to_address>

### Comment 2
<location> `brainstate/transform/_mapping2.py:518` </location>
<code_context>
         return tuple(rand_vals), tuple(rand_recover_vals)

-    def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
+    def __call__(self, *args, **kwargs):
         """Execute the stateful mapping on the given arguments.

</code_context>

<issue_to_address>
**suggestion:** Dropping the return type annotation on `__call__` reduces type clarity for callers.

The previous signature `__call__(...) -> Tuple[Any, Tuple[State, ...]]` provided useful information for type checkers and IDEs. Unless it was incorrect, please restore it; if it was wrong, update it to the correct return type instead of removing it entirely.

```suggestion
    def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
```
</issue_to_address>

### Comment 3
<location> `docs/tutorials/transforms/03_vectorization.ipynb:1213-1216` </location>
<code_context>
     "\n",
     "\n",
     "# Compose: jit -> grad -> vmap\n",
-    "batched_grad = vmap(\n",
+    "batched_grad = vmap2(\n",
</code_context>

<issue_to_address>
**nitpick (typo):** Comment still refers to `vmap` while the tutorial is now consistently about `vmap2`.

This line still mentions `vmap`. Please update the comment to `# Compose: jit -> grad -> vmap2` to match the updated API and avoid confusion.

```suggestion
    "\n",
    "\n",
    "# Compose: jit -> grad -> vmap2\n",
    "batched_grad = vmap2(\n",
```
</issue_to_address>

### Comment 4
<location> `docs/apis/nn.rst:360` </location>
<code_context>

+Functional (non-module) activation functions for flexible composition. These are
+pure functions that can be used directly in ``update()`` methods or combined with
+JAX transformations. Provides the same activations as the layer-based equivalents
+but without state or module overhead.
+
</code_context>

<issue_to_address>
**issue (typo):** Fix the sentence fragment starting with "Provides" to have an explicit subject.

In the phrase "JAX transformations. Provides the same activations as the layer-based equivalents but without state or module overhead.", the second sentence is a fragment. Consider revising to something like "They provide the same activations as the layer-based equivalents but without state or module overhead." or "These functions provide the same activations…" so the sentence has an explicit subject and correct agreement.

Suggested implementation:

```
Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in ``update()`` methods or combined with

```

```
JAX transformations. These functions provide the same activations as the layer-based equivalents but without state or module overhead.

```
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

"# Call once to trigger compilation\n",
"test_input = jnp.array([1.0, 2.0])\n",
"_ = simple_vmap(test_input)\n",
"_ = simple_vmap2(test_input)\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Likely typo: calling simple_vmap2 instead of the defined simple_vmap.

simple_vmap is defined as simple_vmap = vmap2(...) above, but simple_vmap2 is not defined, so this call will raise a NameError. This line should call simple_vmap(test_input) instead.

return tuple(rand_vals), tuple(rand_recover_vals)

def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
def __call__(self, *args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Dropping the return type annotation on __call__ reduces type clarity for callers.

The previous signature __call__(...) -> Tuple[Any, Tuple[State, ...]] provided useful information for type checkers and IDEs. Unless it was incorrect, please restore it; if it was wrong, update it to the correct return type instead of removing it entirely.

Suggested change
def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:

Comment on lines 1213 to +1216
"\n",
"\n",
"# Compose: jit -> grad -> vmap\n",
"batched_grad = vmap(\n",
"batched_grad = vmap2(\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Comment still refers to vmap while the tutorial is now consistently about vmap2.

This line still mentions vmap. Please update the comment to # Compose: jit -> grad -> vmap2 to match the updated API and avoid confusion.

Suggested change
"\n",
"\n",
"# Compose: jit -> grad -> vmap\n",
"batched_grad = vmap(\n",
"batched_grad = vmap2(\n",
"\n",
"\n",
"# Compose: jit -> grad -> vmap2\n",
"batched_grad = vmap2(\n",


Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in ``update()`` methods or combined with
JAX transformations. Provides the same activations as the layer-based equivalents
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (typo): Fix the sentence fragment starting with "Provides" to have an explicit subject.

In the phrase "JAX transformations. Provides the same activations as the layer-based equivalents but without state or module overhead.", the second sentence is a fragment. Consider revising to something like "They provide the same activations as the layer-based equivalents but without state or module overhead." or "These functions provide the same activations…" so the sentence has an explicit subject and correct agreement.

Suggested implementation:

Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in ``update()`` methods or combined with

JAX transformations. These functions provide the same activations as the layer-based equivalents but without state or module overhead.

@sourcery-ai sourcery-ai bot changed the title Doc Update vectorization docs for vmap2 and relax mapping return type Jan 2, 2026
@chaoming0625 chaoming0625 merged commit 3386c16 into main Jan 2, 2026
6 checks passed
@chaoming0625 chaoming0625 deleted the doc branch January 2, 2026 02:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants