Conversation
…entation

- Add BoxCartesianDistribution, BoxCartesianPFEstimator/PBEstimator, and corresponding MLP modules in box_utils.py (simpler, faster alternative to the polar-coordinate approach)
- Update Box.is_action_valid for Cartesian per-dimension semantics
- Add comprehensive tests for new Cartesian components in test_box_utils.py
- Rewrite train_box.py to use Cartesian approach with plotting
- Preserve original training script as train_box_legacy.py
- Update benchmark runner to use Cartesian modules
- Update test fixtures (debug flag, import changes)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Rename Box -> BoxPolar in box.py, restore polar norm-based is_action_valid
- Create BoxCartesian(BoxPolar) in box_cartesian.py, overrides only is_action_valid
- Add Box = BoxCartesian alias in __init__.py for backward compatibility
- Split box_utils.py into box_cartesian_utils.py + box_polar_utils.py
- Reduce box_utils.py to backward-compat re-export shim
- Split test_box_utils.py into test_box_cartesian_utils.py + test_box_polar_utils.py
- Update train_box_legacy.py to import BoxPolar explicitly

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve merge conflicts between local (BoxPolar/BoxCartesian split) and remote (Cartesian refactor). Restore master's polar norm-based validation in BoxPolar since it reimplements the published paper. Polar helper improvements (atan2, no double-precision, compile-friendly BoxPFMLP) are retained in box_polar_utils.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# Conflicts:
#	testing/test_samplers_and_containers.py
- Bug 2: Separate numerical epsilon (1e-6) from env.epsilon (1e-10) in both PF and PB estimators to prevent Jacobian explosions near boundaries
- Bug 3: Boundary exit now uses learned Bernoulli log_prob instead of forced 0, giving gradient signal to prefer high-reward exits
- Bug 1: Zero out log_no_exit for s0 states so first-step log P_F is not systematically underestimated
- Bug 7: Normalize MLP inputs [0,1] -> [-1,1] to make reward modes symmetric (x~0.15 and x~0.85 map to ±0.7)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
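The Bug 7 normalization is a plain affine map from the unit interval onto a symmetric range; a minimal sketch in plain Python (the function name is hypothetical, not the commit's actual code):

```python
def normalize_input(x: float) -> float:
    """Map a coordinate from [0, 1] to [-1, 1] so the two Box reward
    modes (near x=0.15 and x=0.85) land symmetrically around zero."""
    return 2.0 * x - 1.0

# The two reward modes map to roughly -0.7 and +0.7, as the commit notes.
assert abs(normalize_input(0.15) + 0.7) < 1e-12
assert abs(normalize_input(0.85) - 0.7) < 1e-12
assert normalize_input(0.5) == 0.0  # center of the cube maps to the origin
```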
…ncentration

- Bug 6: Raise max_concentration default from 5.0 to 100.0 in both PF and PB estimators (matches gflownet reference beta_params_max)
- Bug 5: Add learned Bernoulli BTS probability to BoxCartesianPBDistribution. Previously BTS was always deterministic (log_prob=0), giving P_B zero gradient from the BTS step in every trajectory. For 1-step trajectories (s0 -> x_T -> BTS), P_B was entirely gradient-free. Now log P_B(BTS|x) uses the learned Bernoulli, closing the TB loop for all trajectory lengths. BoxCartesianPBMLP gains 1 extra output (bts_logit at position 0).
- Bug 4: BTS detection tolerance increased from delta*0.1 to delta*0.5 (subsumed by Bug 5 implementation)
- Add detailed docstring explaining why the learned BTS Bernoulli is critical
- Relax test_subTB_vs_TB tolerance from 1e-4 to 1e-3: non-zero PB log_probs cause slightly more float32 accumulation error (identity still holds)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Temperature annealing (box_cartesian_utils.py):

- Add temperature param to BoxCartesianDistribution and PBDistribution; divides exit/BTS/mixture logits before constructing Bernoulli/Categorical, increasing entropy at T>1 for exploration.
- Add mutable temperature attribute to both estimators so train_box.py can update it each iteration without re-instantiating distributions.

Training improvements (train_box.py):

- Fix LR schedule: scheduler_milestone=0 now computes n_iterations//3 dynamically (3 halvings total), replacing hardcoded 2500 that fired 9x and killed the LR before 10% of training was done.
- Add --temperature_start default 2.0: linearly anneals from T_start to 1.0 over the first half of training. Mirrors gflownet random_action_prob.

Test tolerance (test_parametrizations_and_losses.py):

- Switch to relative tolerance 1e-5 for TB=SubTB identity check.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
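The mechanism behind logit division is easy to verify: softmax of `logits / T` with T > 1 is flatter, so its entropy is higher. A self-contained sketch in plain Python (not the library's implementation, which operates on torch tensors):

```python
import math

def softmax_entropy(logits, temperature=1.0):
    """Entropy (in nats) of softmax(logits / temperature).

    Dividing logits by T > 1 flattens the distribution, so entropy
    rises -- the exploration mechanism the annealing schedule uses."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]
    return -sum(p * math.log(p) for p in probs)

logits = [2.0, 0.5, -1.0]
# Entropy grows monotonically with temperature for fixed logits.
assert softmax_entropy(logits, temperature=2.0) > softmax_entropy(logits, temperature=1.0)
```

As T grows without bound the distribution approaches uniform, whose entropy is log(n) for n categories.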
The test_log_prob_exit_at_boundary test still expected log_prob=0 for boundary exits, but the Bug 3 fix changed these to use the learned Bernoulli. Updated to assert finite valid log probabilities instead. Also fixed pre-existing isort issues in benchmark runner files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
BoxArgs was missing the temperature_start field added in train_box.py, causing AttributeError in CI. Updated JSD targets to match improved convergence from recent bug fixes (~0.09-0.10 across all configs). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
np.float_ was removed in NumPy 2.0. np.float64 is the equivalent type and works across all numpy versions. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
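The portable replacement described above can be checked directly (assumes NumPy is installed; works on both 1.x and 2.x):

```python
import numpy as np

# np.float_ was an alias for np.float64 in NumPy 1.x and was removed
# in NumPy 2.0. Using np.float64 directly is portable across versions.
arr = np.asarray([0.1, 0.2], dtype=np.float64)
assert arr.dtype == np.float64
assert np.dtype(np.float64).itemsize == 8  # IEEE 754 double precision
```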
Simplify box
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add per-phase timing (sample/loss/backward/optimizer) to all benchmark runners with phase breakdown display and JSON output
- Add parameter count logging to catch architecture mismatches
- Fix MLP n_hidden_layers off-by-one: n_hidden_layers=2 now creates 3 Linear layers (was 4), matching standard convention
- Add uniform backward policy option for fair box benchmark comparison
- Optimize box distributions: replace Bernoulli/MixtureSameFamily with inline logsigmoid/logsumexp+Beta operations
- Optimize States: add _make_view() fast constructor, lazy boolean mask caching, pre-allocated dummy Actions tensors
- Fix flaky test: guard against empty transitions in legacy PB helper
- Add benchmark analysis document with cross-library comparison

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
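The hidden-layer convention the fix adopts can be sketched without torch: with `n_hidden_layers` hidden activations, an MLP needs `n_hidden_layers + 1` linear maps. This is an illustrative sketch, not torchgfn's actual MLP code:

```python
def linear_layer_dims(input_dim, hidden_dim, output_dim, n_hidden_layers):
    """Return the (in, out) shape of each Linear layer under the standard
    convention: n_hidden_layers hidden layers => n_hidden_layers + 1
    Linear layers (illustration only)."""
    dims = [input_dim] + [hidden_dim] * n_hidden_layers + [output_dim]
    return list(zip(dims[:-1], dims[1:]))

layers = linear_layer_dims(4, 128, 2, n_hidden_layers=2)
assert len(layers) == 3  # the fixed behavior: 3 Linear layers, not 4
assert layers == [(4, 128), (128, 128), (128, 2)]
```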
- Run all scenarios by default when no --scenario is supplied
- Add --batch-sizes flag (default: 32, 256) to test multiple batch sizes
- Standardize all scenarios to 1000 iterations
- Save combined CSV of all results for easy analysis with pandas
- Require CUDA for benchmark runs
- Fix torchgfn Ising to use TB loss (was FM) to match gfnx
- Fix gfnx Ising coupling matrix (was zeros) by building J with periodic BCs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Implement non-autoregressive BitSequence environment where actions encode (position, word) pairs and positions can be filled in any order, matching the gfnx formulation. Update benchmark runner to use it and add training example that verifies learning (L1 < 0.002 on small configs). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
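One natural way to encode (position, word) pairs as flat action indices is a mixed-radix layout. This is a hypothetical sketch for illustration; the actual encoding in bitSequenceNonAutoregressive.py may differ:

```python
def encode_action(position: int, word: int, n_words: int) -> int:
    """Flatten a (position, word) pair into a single action index
    (hypothetical encoding, shown only to illustrate the idea)."""
    return position * n_words + word

def decode_action(action: int, n_words: int) -> tuple[int, int]:
    """Recover (position, word) from a flat action index."""
    return divmod(action, n_words)

n_words = 4
a = encode_action(3, 2, n_words)
assert a == 14
assert decode_action(a, n_words) == (3, 2)  # round-trips exactly
```

Because positions are explicit in the action, states can be filled in any order, unlike the autoregressive variant that always appends at the end.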
Remove gradient clipping from torchgfn runner (gfnx/gflownet don't clip). Match gflownet's box environment parameters: delta=0.1 (was 0.25) for similar trajectory lengths, and concentration range [0.1, 100.0] (was [0.1, 5.1]) for matching Beta distribution expressivity. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Switch from jax.block_until_ready() (array-level) to jax.devices()[0].synchronize() (device-level) to match the torch.cuda.synchronize() semantics used by the PyTorch runners. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Monkeypatch Grid.states2policy in the benchmark runner to add device=self.device to torch.arange() calls, fixing a RuntimeError when benchmarking on CUDA. Applied in the runner rather than the submodule to keep upstream code unmodified. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Remove try/except around benchmark runs so errors surface immediately and the script exits, rather than printing the error and moving on. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nize

Use device.synchronize() when available, fall back to jax.block_until_ready() on train state arrays for older JAX versions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use --force-reinstall for jax[cuda] to ensure the CUDA plugin is installed even if jax was previously installed without it. Add a post-install check that warns if JAX cannot see a GPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add --cpu flag to allow explicit CPU benchmarking. Without it, CUDA is required. In the gfnx runner, verify JAX sees a GPU when PyTorch has CUDA — crash immediately with a clear message if JAX fell back to CPU. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add 2048 to default batch sizes to measure JAX scaling at higher throughput. Document the extra computation gfnx performs during trajectory rollout (entropy, redundant softmax, double forward pass, trajectory transpose) and its impact on batch size scaling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Lower from 1000 iterations to 100 for faster default runs. The --n-iterations flag can still override for longer benchmarks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
eqx.filter_jit treats Python ints as static values, causing a fresh JIT compilation for every unique idx. With 50 warmup + 100 timed iterations, this meant 150 compilations instead of 1. Converting idx to jnp.array() makes it a traced (dynamic) value so the function is compiled once and reused for all iterations. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
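The static-vs-traced distinction can be modeled with a toy compilation cache: static arguments key the cache by *value*, traced arguments only by shape/dtype. This is an illustration of the mechanism only, not JAX or Equinox internals:

```python
# Toy model of JIT caching. Compilations are keyed by static argument
# values, but only by shape/dtype for traced (array) arguments -- which
# is why a Python int idx forced one compilation per distinct value,
# while a jnp.array idx compiles once.
compile_cache = {}

def fake_jit_call(idx, as_array):
    if as_array:
        key = ("array", "int32", ())   # traced: only shape/dtype matter
    else:
        key = ("static", idx)          # static: the value itself is the key
    if key not in compile_cache:
        compile_cache[key] = True      # "compile" once per key
    return len(compile_cache)

for idx in range(150):
    n_static = fake_jit_call(idx, as_array=False)
for idx in range(150):
    n_total = fake_jit_call(idx, as_array=True)

assert n_static == 150  # one "compilation" per distinct Python int
assert n_total == 151   # the array version adds exactly one more
```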
Now that idx is passed as a JAX array, the int annotation is incorrect and triggers Pyright errors. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
MLPPolicy was a plain Python class, invisible to JAX's pytree system. This meant eqx.partition/block_until_ready couldn't see model arrays, so synchronization wasn't actually waiting for computation to finish (explaining the unrealistic 2ms iteration times). Parameter counting also returned 0. Converting to eqx.Module fixes both issues. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Detail five performance bottlenecks in the gflownet library that cause 45x slowdown from batch_size=32 to 2048: expand-and-compare action indexing, per-element dict lookups, dense one-hot allocation, O(n²) equality checks, and Python-level batch iteration. Include file paths and line numbers for library developer reference. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Patch Cube CUDA device mismatches (monkeypatch torch.zeros/ones on GPU)
- Skip gflownet on Ising 10x10 with batch_size > 128 (64s/iter infeasible)
- Add --patch-gflownet-actions flag for O(batch) action indexing optimization
- Update batch sizes from [32, 256, 1024] to [32, 128, 512]
- Add GPU benchmark results and Ising cubic scaling analysis to ANALYSIS.md
- Document pinned submodule commits in README.md
- Add collect_results.py to amalgamate JSON outputs into a single CSV
- Accept **kwargs in runner __init__ for forward compatibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Pull request overview
Adds a new benchmark/ suite to compare torchgfn against gflownet and gfnx, while introducing new/renamed environments (BoxPolar/BoxCartesian and a non-autoregressive BitSequence) plus a set of performance-focused internal optimizations.
Changes:
- Introduces an end-to-end benchmarking framework (runners, scenarios, setup scripts, analysis, result collection).
- Adds new environments/utilities: `BoxCartesian`, `NonAutoregressiveBitSequence`, and Cartesian Box estimators/distributions; renames legacy `Box` to `BoxPolar` with a compatibility alias.
- Implements several hot-path performance optimizations (States slicing/views + caching, sampler preallocation, conditions propagation, TB weighting numeric consistency).
Reviewed changes
Copilot reviewed 44 out of 47 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tutorials/examples/train_box_legacy.py | New legacy training script targeting BoxPolar (old semantics). |
| tutorials/examples/train_box.py | Updates Box training to Cartesian estimators, adds temperature annealing, LR milestone logic, and trajectory plotting. |
| tutorials/examples/train_bitsequence_non_autoregressive.py | New example script for the non-autoregressive BitSequence environment. |
| tutorials/examples/test_scripts.py | Updates Box script args/tests to include temperature_start and new JSD targets. |
| tutorials/examples/output/.gitignore | Ignores generated tutorial outputs (plots/checkpoints). |
| testing/test_samplers_and_containers.py | Switches sampler/container tests from Box to BoxPolar + imports moved utils. |
| testing/test_probability_calculations.py | Guards PB transition log-prob computation when no valid next states exist. |
| testing/test_parametrizations_and_losses.py | Updates Box parametrization tests to Cartesian estimators/modules + seeding. |
| testing/test_gflownet.py | Updates GFlowNet generic tests to Cartesian Box estimators/modules. |
| testing/test_environments.py | Reworks Box forward-step tests to validate Cartesian semantics. |
| src/gfn/utils/modules.py | Fixes/clarifies MLP hidden-layer semantics and adjusts noisy-layer logic + assertions. |
| src/gfn/states.py | Adds fast _make_view constructor, caches is_initial_state/is_sink_state, and optimizes slicing. |
| src/gfn/samplers.py | Optimizes trajectory sampling loop (prealloc actions, bypass conditions setter, debug-only storage checks). |
| src/gfn/gym/tests/test_box_utils.py | Removes old Box util tests (polar-focused). |
| src/gfn/gym/tests/test_box_polar_utils.py | Adds polar utils tests and additional Cartesian coverage. |
| src/gfn/gym/tests/test_box_cartesian_utils.py | Adds dedicated test suite for Cartesian Box distribution/estimators. |
| src/gfn/gym/helpers/box_cartesian_utils.py | New Cartesian Box estimators, MLP heads, and forward/backward distributions. |
| src/gfn/gym/helpers/bayesian_structure/jsd.py | Replaces deprecated np.float_ with explicit np.float64. |
| src/gfn/gym/box_cartesian.py | New BoxCartesian env with per-dimension action validation. |
| src/gfn/gym/box.py | Renames legacy Box to BoxPolar and documents polar validation semantics. |
| src/gfn/gym/bitSequenceNonAutoregressive.py | New NonAutoregressiveBitSequence discrete environment implementation. |
| src/gfn/gym/__init__.py | Exports new envs and sets Box = BoxCartesian for backward compatibility. |
| src/gfn/gflownet/sub_trajectory_balance.py | Adds TB-equivalent fast path to match TBGFlowNet numeric behavior and avoid cumsum drift. |
| src/gfn/env.py | Skips conditions setter in internal step/backward_step to reduce overhead. |
| pyproject.toml | Disables optional (None) handling reports in pyright. |
| benchmark/setup_benchmark.sh | New setup helper to install benchmark dependencies across platforms. |
| benchmark/profile_hypergrid.py | New profiling script for HyperGrid iteration breakdown. |
| benchmark/profile_box.py | New profiling script for Box iteration breakdown. |
| benchmark/outputs/.gitkeep | Ensures benchmark outputs directory exists in git. |
| benchmark/outputs/.gitignore | Ignores benchmark result artifacts (CSV/JSON). |
| benchmark/lib_runners/torchgfn_runner.py | Runner for benchmarking torchgfn across supported envs. |
| benchmark/lib_runners/gflownet_runner.py | Runner for benchmarking upstream gflownet (Hydra compose + monkeypatches). |
| benchmark/lib_runners/base.py | Base benchmark dataclasses, runner interface, and result utilities. |
| benchmark/lib_runners/__init__.py | Exposes benchmark runner types. |
| benchmark/gfnx | Adds gfnx as a git submodule pin. |
| benchmark/gflownet | Adds gflownet as a git submodule pin. |
| benchmark/collect_results.py | Collects benchmark JSON outputs into a combined CSV. |
| benchmark/benchmark_libraries.py | Main benchmark entrypoint: scenarios, timing, aggregation, outputs. |
| benchmark/README.md | Benchmark documentation (setup, scenarios, usage, output format). |
| benchmark/ANALYSIS.md | Detailed performance analysis and reproducibility notes. |
| .gitmodules | Declares benchmark submodules. |
| .gitignore | Ignores .claude/ and fixes .lprof formatting. |
Comments suppressed due to low confidence (3)
src/gfn/utils/modules.py:1
- The noisy-layer allocation math can create an extra hidden-to-hidden layer when `n_hidden_layers=1` and `n_noisy_layers>=2` (because `n_extra_hidden=0` but `n_noisy_hidden_layers=1`). This changes the network depth and parameter count unexpectedly. Fix by ensuring `n_noisy_hidden_layers` is capped by `n_extra_hidden` and by accounting separately for whether the input projection and output layer are noisy (e.g., reserve noise for the output first, then allocate any remaining noisy layers to the extra hidden-to-hidden layers only).
tutorials/examples/train_bitsequence_non_autoregressive.py:1 - This counts terminal-state frequencies with an O(n_states * n_samples) loop and full-tensor comparisons per state. For the documented 'feasible' case (e.g. `n_states=4096`, `n_samples≈409600`), this becomes prohibitively large. A faster approach is to encode each terminal state into a unique integer key (e.g., base-`n_words` positional encoding) and then use `torch.bincount` (or `unique(return_counts=True)`) to compute counts in ~O(n_samples).
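The suggested counting scheme, sketched with NumPy for portability (the review proposes `torch.bincount`; the sizes here match the comment's example and are otherwise illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, seq_len, n_words = 409_600, 6, 4   # 4**6 = 4096 terminal states
samples = rng.integers(0, n_words, size=(n_samples, seq_len))

# Encode each terminal state as a unique base-n_words integer key, then
# count all occurrences in one vectorized pass (~O(n_samples)), instead
# of comparing every state against every sample.
weights = n_words ** np.arange(seq_len)
keys = samples @ weights
counts = np.bincount(keys, minlength=n_words**seq_len)

assert counts.sum() == n_samples
assert counts.shape == (4096,)
```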
tutorials/examples/train_box.py:1 - The comment says the default fires at 33%, 67%, and 100%, but the milestone list uses `int(n_iterations / milestone)`, which typically excludes the final step (e.g., with `milestone = n_iterations//3`, this yields milestones near 33%, 67%, and ~99% depending on divisibility). Either adjust the milestone construction to match the documented behavior, or update the comment to reflect the actual milestone positions.
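The divisibility issue is easy to reproduce with one plausible construction of the milestone list (illustrative, not the script's exact code):

```python
n_iterations = 10_000
milestone = n_iterations // 3  # 3333

# Milestones at every multiple of `milestone` up to n_iterations.
milestones = list(range(milestone, n_iterations + 1, milestone))

assert milestones == [3333, 6666, 9999]
# The last milestone lands at ~99%, not 100%, whenever the division
# is inexact -- exactly the mismatch the review comment describes.
assert milestones[-1] != n_iterations
```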
```python
non_exit_actions = actions[~actions.is_exit]
non_terminal_states = states[~actions.is_exit]

if len(non_exit_actions) == 0:
```

`is_action_valid()` returns `True` when all actions are exits, without checking whether exit is valid from the corresponding states (notably: exiting from s0 should be invalid per the Cartesian distribution logic). Consider explicitly validating exit actions before early-returning (e.g., reject `actions.is_exit & states.is_initial_state` for forward steps; and apply any other exit-availability rules the env intends to enforce).

Suggested change:

```python
if len(non_exit_actions) == 0:
    # All actions are exits. For forward steps, exiting from s0 is invalid
    # under the Cartesian semantics, so reject those cases explicitly.
    if not backward:
        if torch.any(states.is_initial_state & actions.is_exit):
            return False
```
```python
def log_prob(self, actions: Tensor) -> Tensor:
    """Compute log probability using Cartesian per-dimension approach."""
    # Identify exit actions
    is_exit = torch.all(actions == float("-inf"), dim=-1)

    # For non-exit: replace -inf with valid placeholder to avoid NaN in computation
    if is_exit.any():
        safe_actions = torch.where(
            is_exit.unsqueeze(-1),
            self.min_incr + 0.5 * self.max_range,
            actions,
        )
    else:
        safe_actions = actions

    # Convert absolute to relative: r = (action - min_incr) / max_range
    safe_max_range = self.max_range.clamp(min=self.epsilon)
    r = (safe_actions - self.min_incr) / safe_max_range
    r = r.clamp(self.epsilon, 1 - self.epsilon)
```

`BoxCartesianDistribution.log_prob()` clamps `r` into `(eps, 1-eps)` without enforcing the action support, so out-of-range actions (e.g., < delta for non-s0, or > 1 - state) can receive finite log-prob instead of -inf. This breaks the expected 'probability 0 outside support' invariant and can hide invalid-action issues. Suggest adding an explicit validity mask based on state (s0 vs non-s0) and setting log_probs to -inf for non-exit actions that violate per-dimension bounds.
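The masking fix the review suggests can be sketched with NumPy (illustration only, not the library's torch implementation; `lo`/`hi` stand in for the per-dimension bounds):

```python
import numpy as np

def masked_log_prob(log_prob, actions, lo, hi):
    """Enforce 'probability 0 outside support': keep the computed
    log-prob only where every action dimension lies in [lo, hi],
    otherwise return -inf. Sketch of the suggested fix."""
    in_support = np.all((actions >= lo) & (actions <= hi), axis=-1)
    return np.where(in_support, log_prob, -np.inf)

actions = np.array([[0.2, 0.3], [0.05, 0.3]])  # second row violates lo=0.1
lp = masked_log_prob(np.array([-1.0, -1.0]), actions, lo=0.1, hi=0.9)
assert lp[0] == -1.0             # in-support action keeps its log-prob
assert lp[1] == float("-inf")    # out-of-support action is rejected
```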
```python
class NonAutoregressiveBitSequence(DiscreteEnv):
    """Non-autoregressive BitSequence environment.
```
This PR introduces a new public environment (NonAutoregressiveBitSequence) but doesn't add a corresponding test module validating core invariants (forward/backward mask correctness, step/backward_step roundtrip, exit-only-at-terminal behavior, and reward monotonicity vs Hamming distance). Adding targeted pytest coverage would help prevent regressions in action encoding/masking and terminal/exit semantics.
Description
Via the `benchmark` folder, runs an extensive benchmark of our library against `gflownet` and `gfnx`. The benchmark, after 50 iterations of warm-up, times 100 training iterations for each library, for each environment, for each batch size.

Due to the trouble scaling gflownet batch sizes, we had to cap our max batch sizes. However, if this is patched in their project, all of these benchmarks can be re-run.

This required me to add new environments to our library to correctly compare with the other libraries in a few cases:

- A Cartesian Box environment matching the one in the `gflownet` library, which is a standard diffusion-style environment.
- A non-autoregressive BitSequence environment matching the one in `gfnx`, which is somewhat similar to a discrete diffusion environment (instead of a sequence sampling environment).

Misc changes: