Add a custom path for matrix multiplication to leverage fast kernel library#1
Open
Shukla-Gaurav wants to merge 5 commits intomainfrom
Open
Add a custom path for matrix multiplication to leverage fast kernel library#1Shukla-Gaurav wants to merge 5 commits intomainfrom
Shukla-Gaurav wants to merge 5 commits intomainfrom
Conversation
added 5 commits
March 3, 2026 15:09
Introduces a new MIPS dialect that acts as a semantic abstraction layer for hardware-specific matrix multiply operations inside the IREE compiler. Key components: - IR/MIPSBase.td / MIPSDialect.h/.cpp: dialect definition, namespace ::mlir::iree_compiler::IREE::MIPS, dependent on func/memref/tensor. - IR/MIPSOps.td / MIPSOps.h/.cpp: mips.matmul op, tensor-only, Destination-Passing-Style (DPS). Verifier checks 2-D tensor shapes (lhs[MxK], rhs[KxN], init[MxN]) and element-type consistency. Implements ReifyRankedShapedTypeOpInterface and MemoryEffectsOpInterface. - IR/MIPSBufferizableOpInterface.cpp: eliminates mips.matmul entirely during One-Shot Bufferize by emitting func.call @my_matmul_kernel directly with the memref (base_ptr, offset, stride0, stride1) ABI. Uses memref.memory_space_cast to strip IREE HAL memory spaces from the base pointers so the function declaration stays stable across eraseHALDescriptorTypeFromMemRef. - Transforms/Passes.td/.h/.cpp: pass registry for LowerMIPSToFuncCallPass (now a no-op; kept for pipeline compatibility since bufferization handles the lowering directly). - Transforms/LowerMIPSToFuncCall.cpp: no-op pass stub. - DispatchCreation/FormDispatchRegions.cpp: teach IREE's dispatch formation to treat mips.matmul as a compute-heavy op eligible for outlining into flow.dispatch.workgroups. - Tools/init_iree_dialects.h: register MIPSDialect and its BufferizableOpInterface external models. - Tools/init_iree_passes.h: register MIPS passes with the global pass registry.
Adds the frontend conversion from torch.aten.mm to mips.matmul and wires the MIPS dialect into the IREE LLVMCPU codegen pipeline. Torch InputConversion changes: - ConvertTorchToMIPS.cpp: ConvertAtenMmToMIPSMatmul pattern rewrites torch.aten.mm to mips.matmul with a zero-initialized init tensor (bufferization.alloc_tensor). The pass runs before the standard ConvertTorchToLinalgPass so mips.matmul takes precedence over the generic linalg.matmul path. - Passes.td / Passes.h / Passes.cpp: declare ConvertTorchToMIPSPass and add it to the torch-to-iree pipeline under the use-mips-matmul option. - test/convert_torch_to_mips.mlir: FileCheck test verifying that torch.aten.mm is replaced by mips.matmul after the pass. LLVMCPU codegen pipeline changes: - Passes.cpp: insert LowerMIPSToFuncCallPass (no-op stub) in the post-bufferize section of buildLLVMCPUCodegenPassPipeline. The actual lowering to func.call is performed during One-Shot Bufferize by MIPSBufferizableOpInterface; this stub ensures the pass slot is reserved for future use and keeps the pipeline definition explicit. - CMakeLists.txt: add iree_compiler_Dialect_MIPS_Transforms_Transforms dependency to the LLVMCPU codegen target.
Provides the runtime kernel that backs mips.matmul dispatches. The kernel is packaged as an IREE executable plugin (implements iree_hal_executable_plugin_query) rather than a plain shared library because IREE's LLVMCPU system-dylib dispatch format resolves external function references through an internal import table (not standard ELF dynamic linking). The plugin's resolve() callback maps the symbol name "my_matmul_kernel" to the import-ABI wrapper. Usage: iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ...
Provides two convenience scripts for exercising the full
torch.aten.mm → mips.matmul → func.call → vmfb → kernel plugin pipeline.
compile_mips.zsh:
Step 1 — verifies torch.aten.mm is converted to mips.matmul by the
ConvertTorchToMIPSPass (iree-opt smoke check).
Step 2 — runs the full torch-to-iree pipeline with use-mips-matmul=true,
producing the IREE input IR (/tmp/mm_iree.mlir).
Step 3 — compiles the IREE input IR to a vmfb (/tmp/mm_mips.vmfb)
with --mlir-print-ir-after-all, writing a per-pass IR dump
to /tmp/mm_mips_ir_dump.mlir for debugging.
run_mips.zsh:
Runs iree-run-module with --executable_plugin pointing at the
built libmy_matmul_kernel.dylib. Tests A * I = A (matrix multiplied
by identity) and prints the result for visual verification.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.