Improve performance of dhconv contraction#861

Open
mcgibbon wants to merge 15 commits into main from feature/dhconv_optimization

@mcgibbon
Contributor

@mcgibbon mcgibbon commented Feb 20, 2026

This PR changes the shape of the weight tensor in SpectralConvS2 from [group, in, out, lat] to [group, lat, out, in], so that the matrix multiplication dimensions [out, in] are the fastest-varying dimensions in memory, with [in] being the fastest. Backwards compatibility is included for existing checkpoints.
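As a rough illustration of the layout change described above (not the code in this PR), the sketch below uses NumPy as a stand-in for the real tensor library. The input shape `[batch, group, in, lat, lon]`, the real-valued data, and the einsum strings are assumptions made for the example. It shows that permuting the weight to `[group, lat, out, in]` leaves the contraction result unchanged while making `in` the contiguous axis.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, group, c_in, c_out, lat, lon = 2, 3, 4, 5, 6, 7

# Hypothetical spectral input laid out as [batch, group, in, lat, lon].
x = rng.standard_normal((batch, group, c_in, lat, lon))

# Old weight layout: [group, in, out, lat].
w_old = rng.standard_normal((group, c_in, c_out, lat))

# New weight layout: [group, lat, out, in]. Copying to a contiguous array
# makes `in` the fastest-varying axis and `out` the next fastest.
w_new = np.ascontiguousarray(w_old.transpose(0, 3, 2, 1))

# For each group and latitude, contract the `in` channel against an
# [out, in] matrix. Only the index order of the weight differs.
y_old = np.einsum("bgixy,giox->bgoxy", x, w_old)
y_new = np.einsum("bgixy,gxoi->bgoxy", x, w_new)

same = np.allclose(y_old, y_new)
```

With the new layout, each `[out, in]` matrix for a given group and latitude occupies one contiguous block of memory, which is the property the description relies on.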

Changes:

  • No public API changes

  • Tests added

@mcgibbon
Contributor Author

mcgibbon commented Feb 20, 2026

The commits are ordered so that each step can be verified to produce identical model behavior. A separate commit updates the regression target; it only trivially changes the weight initialization (making it cleaner, but breaking regression).

I still need to write a test that ensures previous checkpoints have their behavior unchanged, but loading them should already work, using the same method we use to support old checkpoints without a group dim.
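A minimal sketch of what such a checkpoint upgrade could look like, again in NumPy. The function name `upgrade_legacy_weight` is hypothetical, the pre-group shape `[in, out, lat]` is an assumption, and the caller is assumed to already know that the weight comes from a legacy checkpoint. How the PR actually detects legacy weights is not shown here.

```python
import numpy as np


def upgrade_legacy_weight(weight):
    """Convert a legacy weight to the [group, lat, out, in] layout.

    Hypothetical helper: assumes `weight` is known to be legacy, with shape
    [in, out, lat] (no group dim, assumed) or [group, in, out, lat].
    """
    if weight.ndim == 3:
        # Add a leading group dim of size 1: [in, out, lat] -> [1, in, out, lat].
        weight = weight[np.newaxis]
    if weight.ndim != 4:
        raise ValueError(f"expected a 3D or 4D weight, got {weight.ndim}D")
    # [group, in, out, lat] -> [group, lat, out, in], stored contiguously.
    return np.ascontiguousarray(weight.transpose(0, 3, 2, 1))


legacy = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)  # [in, out, lat]
upgraded = upgrade_legacy_weight(legacy)
```

Because both layouts are 4D, the shape alone cannot distinguish old from new, so a real implementation would need some other signal to decide when to apply the conversion.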

Before and after (note the change to dhconv and the total):
[two images, before and after; not preserved]
Note that with 8 groups, we no longer see an improvement, as the runtime is already low (15 ms on a T4).

@mcgibbon mcgibbon changed the title Feature/dhconv optimization Improve performance of S2Convolution contraction Feb 24, 2026
@mcgibbon mcgibbon changed the title Improve performance of S2Convolution contraction Improve performance of dhconv contraction Feb 24, 2026
@mcgibbon
Contributor Author

mcgibbon commented Feb 24, 2026

When reviewing, please check that the regression file updates are valid by looking per-commit. The weight shape refactor is done first in a way that maintains backwards compatibility and passes regression tests, with the regression targets only updated afterwards with a trivial change to the way the weights are initialized (which changes the values they get initialized to).

I added a checkpoint-load regression test in another PR, so this is no longer an issue.

@mcgibbon
Contributor Author

I added a regression test for the output from a loaded checkpoint being the same, using an input checkpoint generated from main. The test caught a bug, so it's working properly.
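The idea behind a test like this can be sketched as follows. This is an illustrative stand-in, not the test added in the repository: the functions `legacy_forward`, `new_forward`, and `load_weight` are hypothetical, and an in-memory `.npz` buffer plays the role of the checkpoint generated from main.

```python
import io

import numpy as np


def legacy_forward(w_old, x):
    # Stand-in for the code on main: weight [group, in, out, lat], x [group, in, lat].
    return np.einsum("giox,gix->gox", w_old, x)


def new_forward(w_new, x):
    # Stand-in for the code on this branch: weight [group, lat, out, in].
    return np.einsum("gxoi,gix->gox", w_new, x)


def load_weight(w_old):
    # Hypothetical backwards-compatibility step applied when loading a checkpoint.
    return np.ascontiguousarray(w_old.transpose(0, 3, 2, 1))


rng = np.random.default_rng(0)
w_old = rng.standard_normal((2, 3, 4, 5))
x = rng.standard_normal((2, 3, 5))

# "Checkpoint generated from main": legacy weight plus the output it produced.
checkpoint = io.BytesIO()
np.savez(checkpoint, weight=w_old, x=x, expected=legacy_forward(w_old, x))
checkpoint.seek(0)

# On the new branch: load, convert the weight, and compare against the stored output.
data = np.load(checkpoint)
actual = new_forward(load_weight(data["weight"]), data["x"])
max_abs_diff = float(np.max(np.abs(actual - data["expected"])))
```

If the conversion used the wrong axis order, the comparison would either fail on shape or show a large difference, which is the kind of bug a test like this is meant to catch.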

@mcgibbon mcgibbon marked this pull request as ready for review February 24, 2026 15:47
@mcgibbon mcgibbon changed the base branch from main to feature/csfno_checkpoint_regression February 24, 2026 15:56
Base automatically changed from feature/csfno_checkpoint_regression to main February 24, 2026 19:35
mcgibbon added a commit that referenced this pull request Feb 24, 2026
These changes are used to test performance improvements in
#861

Changes:
- Added `asdict` and `from_dict` methods to Context
- Added regression test which ensures the output of the conditional sfno
after loading checkpoint data from disk is unchanged

- [x] Tests added

---------

Co-authored-by: W. Andre Perkins <frodre@gmail.com>