Skip to content

Add a custom path for matrix multiplication to leverage fast kernel library#1

Open
Shukla-Gaurav wants to merge 5 commits intomainfrom
gaurav/mips_matmul_flow
Open

Add a custom path for matrix multiplication to leverage fast kernel library#1
Shukla-Gaurav wants to merge 5 commits intomainfrom
gaurav/mips_matmul_flow

Conversation

@Shukla-Gaurav
Copy link
Collaborator

No description provided.

Gaurav Shukla added 5 commits March 3, 2026 15:09
Introduces a new MIPS dialect that acts as a semantic abstraction layer
for hardware-specific matrix multiply operations inside the IREE compiler.

Key components:

- IR/MIPSBase.td / MIPSDialect.h/.cpp: dialect definition, namespace
  ::mlir::iree_compiler::IREE::MIPS, dependent on func/memref/tensor.

- IR/MIPSOps.td / MIPSOps.h/.cpp: mips.matmul op, tensor-only,
  Destination-Passing-Style (DPS). Verifier checks 2-D tensor shapes
  (lhs[MxK], rhs[KxN], init[MxN]) and element-type consistency.
  Implements ReifyRankedShapedTypeOpInterface and MemoryEffectsOpInterface.

- IR/MIPSBufferizableOpInterface.cpp: eliminates mips.matmul entirely
  during One-Shot Bufferize by emitting func.call @my_matmul_kernel
  directly with the memref (base_ptr, offset, stride0, stride1) ABI.
  Uses memref.memory_space_cast to strip IREE HAL memory spaces from
  the base pointers so the function declaration stays stable across
  eraseHALDescriptorTypeFromMemRef.

- Transforms/Passes.td/.h/.cpp: pass registry for
  LowerMIPSToFuncCallPass (now a no-op; kept for pipeline compatibility
  since bufferization handles the lowering directly).

- Transforms/LowerMIPSToFuncCall.cpp: no-op pass stub.

- DispatchCreation/FormDispatchRegions.cpp: teach IREE's dispatch
  formation to treat mips.matmul as a compute-heavy op eligible for
  outlining into flow.dispatch.workgroups.

- Tools/init_iree_dialects.h: register MIPSDialect and its
  BufferizableOpInterface external models.

- Tools/init_iree_passes.h: register MIPS passes with the global
  pass registry.
Adds the frontend conversion from torch.aten.mm to mips.matmul and
wires the MIPS dialect into the IREE LLVMCPU codegen pipeline.

Torch InputConversion changes:

- ConvertTorchToMIPS.cpp: ConvertAtenMmToMIPSMatmul pattern rewrites
  torch.aten.mm to mips.matmul with a zero-initialized init tensor
  (bufferization.alloc_tensor). The pass runs before the standard
  ConvertTorchToLinalgPass so mips.matmul takes precedence over the
  generic linalg.matmul path.

- Passes.td / Passes.h / Passes.cpp: declare ConvertTorchToMIPSPass and
  add it to the torch-to-iree pipeline under the use-mips-matmul option.

- test/convert_torch_to_mips.mlir: FileCheck test verifying that
  torch.aten.mm is replaced by mips.matmul after the pass.

LLVMCPU codegen pipeline changes:

- Passes.cpp: insert LowerMIPSToFuncCallPass (no-op stub) in the
  post-bufferize section of buildLLVMCPUCodegenPassPipeline. The actual
  lowering to func.call is performed during One-Shot Bufferize by
  MIPSBufferizableOpInterface; this stub ensures the pass slot is
  reserved for future use and keeps the pipeline definition explicit.

- CMakeLists.txt: add iree_compiler_Dialect_MIPS_Transforms_Transforms
  dependency to the LLVMCPU codegen target.
Provides the runtime kernel that backs mips.matmul dispatches.

The kernel is packaged as an IREE executable plugin (implements
iree_hal_executable_plugin_query) rather than a plain shared library
because IREE's LLVMCPU system-dylib dispatch format resolves external
function references through an internal import table (not standard ELF
dynamic linking). The plugin's resolve() callback maps the symbol name
"my_matmul_kernel" to the import-ABI wrapper.

Usage:
  iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ...
Provides two convenience scripts for exercising the full
torch.aten.mm → mips.matmul → func.call → vmfb → kernel plugin pipeline.

compile_mips.zsh:
  Step 1 — verifies torch.aten.mm is converted to mips.matmul by the
            ConvertTorchToMIPSPass (iree-opt smoke check).
  Step 2 — runs the full torch-to-iree pipeline with use-mips-matmul=true,
            producing the IREE input IR (/tmp/mm_iree.mlir).
  Step 3 — compiles the IREE input IR to a vmfb (/tmp/mm_mips.vmfb)
            with --mlir-print-ir-after-all, writing a per-pass IR dump
            to /tmp/mm_mips_ir_dump.mlir for debugging.

run_mips.zsh:
  Runs iree-run-module with --executable_plugin pointing at the
  built libmy_matmul_kernel.dylib. Tests A * I = A (matrix multiplied
  by identity) and prints the result for visual verification.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant