Open
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This PR fixes issues in the QMoE CPU operator implementation, specifically correcting bias handling logic and updating attribute naming to match the actual C++ implementation. The changes also improve MLAS documentation for better clarity on input buffer layout requirements.
Changes:
- Fixed FC2 bias handling in QMoE CPU operator by tracking when MLAS DirectQ4Gemm adds bias
- Added transpose logic to convert weight matrices from [N, K] to [K, N] layout required by MlasQ4GemmPackB
- Updated Python tests to use
swiglu_fusionattribute instead of incorrectswiglu_interleavedattribute - Enhanced MLAS documentation to clarify that MlasQ4GemmPackB expects FpData with shape [K, N]
- Added proper bias collection and passing in Python test infrastructure
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| onnxruntime/core/mlas/inc/mlas_q4.h | Updated documentation for MlasQ4GemmPackB to clarify FpData shape [K, N] and parameter meanings |
| onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc | Added transpose logic for weight matrices, renamed fc2_bias_handled_by_q4_gemm to fc2_bias_added_by_mlas, removed unused fc1_used_direct_q4 flag |
| onnxruntime/test/python/transformers/test_qmoe_cpu.py | Migrated from swiglu_interleaved to swiglu_fusion attribute, added bias collection/passing logic, updated swiglu function signature, improved weight interleaving for swiglu_fusion=1 |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR addresses several issues in the QMoE CPU implementation, improves MLAS documentation.
Changes
1. QMoE CPU Operator Fixes
fc2_bias_handled_by_q4_gemmtofc2_bias_added_by_mlasand updated the logic to consistently track whether FC2 bias has been applied. This ensures that bias is not double-counted or missed when usingDirectQ4Gemm.swiglu_interleavedtoswiglu_fusionin both the C++ operator and the Python test infrastructure to align with the latest QMoE implementation standards.2. MLAS Documentation
MlasQ4GemmPackBto specify that the inputFpDatabuffer expects a shape of[K, N]. This helps prevent layout-related errors in future integrations.3. Test Updates
onnxruntime/test/python/transformers/test_qmoe_cpu.pyto useswiglu_fusionand improved the test structure for better parity checks with PyTorch.Verification
test_qmoe_cpu.pyto ensure all QMoE parity tests pass on CPU.