-
Notifications
You must be signed in to change notification settings - Fork 2
Update vectorization docs for vmap2 and relax mapping return type #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Reviewer's GuideUpdates the vectorization tutorial and references to consistently document Class diagram for StatefulMapping and vmap2 after call signature changeclassDiagram
class StatefulMapping {
+__call__(*args, **kwargs) Any
}
class vmap2 {
}
vmap2 --> StatefulMapping : wraps
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
|
@sourcery-ai title |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - I've found 4 issues, and left some high level feedback:
- In the tutorial notebook you renamed the imported symbol to
vmap2but left some call sites inconsistent (e.g.,simple_vmapis defined but later_ = simple_vmap2(test_input)is called), which will raise a NameError when executing the notebook; please align the variable names with the imported function name throughout. - You removed the return type annotation from
StatefulMapping.__call__in_mapping2.py; if the method still always returns a(output, states)tuple, keeping or updating a precise return type would help static checking and IDE support.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In the tutorial notebook you renamed the imported symbol to `vmap2` but left some call sites inconsistent (e.g., `simple_vmap` is defined but later `_ = simple_vmap2(test_input)` is called), which will raise a NameError when executing the notebook; please align the variable names with the imported function name throughout.
- You removed the return type annotation from `StatefulMapping.__call__` in `_mapping2.py`; if the method still always returns a `(output, states)` tuple, keeping or updating a precise return type would help static checking and IDE support.
## Individual Comments
### Comment 1
<location> `docs/tutorials/transforms/03_vectorization.ipynb:1105` </location>
<code_context>
"# Call once to trigger compilation\n",
"test_input = jnp.array([1.0, 2.0])\n",
- "_ = simple_vmap(test_input)\n",
+ "_ = simple_vmap2(test_input)\n",
"\n",
"# Access the compiled Jaxpr\n",
</code_context>
<issue_to_address>
**issue (bug_risk):** Likely typo: calling `simple_vmap2` instead of the defined `simple_vmap`.
`simple_vmap` is defined as `simple_vmap = vmap2(...)` above, but `simple_vmap2` is not defined, so this call will raise a `NameError`. This line should call `simple_vmap(test_input)` instead.
</issue_to_address>
### Comment 2
<location> `brainstate/transform/_mapping2.py:518` </location>
<code_context>
return tuple(rand_vals), tuple(rand_recover_vals)
- def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
+ def __call__(self, *args, **kwargs):
"""Execute the stateful mapping on the given arguments.
</code_context>
<issue_to_address>
**suggestion:** Dropping the return type annotation on `__call__` reduces type clarity for callers.
The previous signature `__call__(...) -> Tuple[Any, Tuple[State, ...]]` provided useful information for type checkers and IDEs. Unless it was incorrect, please restore it; if it was wrong, update it to the correct return type instead of removing it entirely.
```suggestion
def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]:
```
</issue_to_address>
### Comment 3
<location> `docs/tutorials/transforms/03_vectorization.ipynb:1213-1216` </location>
<code_context>
"\n",
"\n",
"# Compose: jit -> grad -> vmap\n",
- "batched_grad = vmap(\n",
+ "batched_grad = vmap2(\n",
</code_context>
<issue_to_address>
**nitpick (typo):** Comment still refers to `vmap` while the tutorial is now consistently about `vmap2`.
This line still mentions `vmap`. Please update the comment to `# Compose: jit -> grad -> vmap2` to match the updated API and avoid confusion.
```suggestion
"\n",
"\n",
"# Compose: jit -> grad -> vmap2\n",
"batched_grad = vmap2(\n",
```
</issue_to_address>
### Comment 4
<location> `docs/apis/nn.rst:360` </location>
<code_context>
+Functional (non-module) activation functions for flexible composition. These are
+pure functions that can be used directly in ``update()`` methods or combined with
+JAX transformations. Provides the same activations as the layer-based equivalents
+but without state or module overhead.
+
</code_context>
<issue_to_address>
**issue (typo):** Fix the sentence fragment starting with "Provides" to have an explicit subject.
In the phrase "JAX transformations. Provides the same activations as the layer-based equivalents but without state or module overhead.", the second sentence is a fragment. Consider revising to something like "They provide the same activations as the layer-based equivalents but without state or module overhead." or "These functions provide the same activations…" so the sentence has an explicit subject and correct agreement.
Suggested implementation:
```
Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in ``update()`` methods or combined with
```
```
JAX transformations. These functions provide the same activations as the layer-based equivalents but without state or module overhead.
```
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| "# Call once to trigger compilation\n", | ||
| "test_input = jnp.array([1.0, 2.0])\n", | ||
| "_ = simple_vmap(test_input)\n", | ||
| "_ = simple_vmap2(test_input)\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (bug_risk): Likely typo: calling simple_vmap2 instead of the defined simple_vmap.
simple_vmap is defined as simple_vmap = vmap2(...) above, but simple_vmap2 is not defined, so this call will raise a NameError. This line should call simple_vmap(test_input) instead.
| return tuple(rand_vals), tuple(rand_recover_vals) | ||
|
|
||
| def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]: | ||
| def __call__(self, *args, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Dropping the return type annotation on __call__ reduces type clarity for callers.
The previous signature __call__(...) -> Tuple[Any, Tuple[State, ...]] provided useful information for type checkers and IDEs. Unless it was incorrect, please restore it; if it was wrong, update it to the correct return type instead of removing it entirely.
| def __call__(self, *args, **kwargs): | |
| def __call__(self, *args, **kwargs) -> Tuple[Any, Tuple[State, ...]]: |
| "\n", | ||
| "\n", | ||
| "# Compose: jit -> grad -> vmap\n", | ||
| "batched_grad = vmap(\n", | ||
| "batched_grad = vmap2(\n", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick (typo): Comment still refers to vmap while the tutorial is now consistently about vmap2.
This line still mentions vmap. Please update the comment to # Compose: jit -> grad -> vmap2 to match the updated API and avoid confusion.
| "\n", | |
| "\n", | |
| "# Compose: jit -> grad -> vmap\n", | |
| "batched_grad = vmap(\n", | |
| "batched_grad = vmap2(\n", | |
| "\n", | |
| "\n", | |
| "# Compose: jit -> grad -> vmap2\n", | |
| "batched_grad = vmap2(\n", |
|
|
||
| Functional (non-module) activation functions for flexible composition. These are | ||
| pure functions that can be used directly in ``update()`` methods or combined with | ||
| JAX transformations. Provides the same activations as the layer-based equivalents |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue (typo): Fix the sentence fragment starting with "Provides" to have an explicit subject.
In the phrase "JAX transformations. Provides the same activations as the layer-based equivalents but without state or module overhead.", the second sentence is a fragment. Consider revising to something like "They provide the same activations as the layer-based equivalents but without state or module overhead." or "These functions provide the same activations…" so the sentence has an explicit subject and correct agreement.
Suggested implementation:
Functional (non-module) activation functions for flexible composition. These are
pure functions that can be used directly in ``update()`` methods or combined with
JAX transformations. These functions provide the same activations as the layer-based equivalents but without state or module overhead.
Summary by Sourcery
Update vectorization tutorial and mapping API to reflect the
vmap2transform instead ofvmap.Enhancements:
__call__method return type in the stateful mapping implementation to be unconstrained for greater flexibility.Documentation:
vmap2, updating narrative text, headings, and code examples accordingly.