From 37b40b44dcf48f297721b9367bf2e75bfa055eb6 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:09:18 +0530 Subject: [PATCH 1/5] [MIPS] Add MIPS dialect with mips.matmul op Introduces a new MIPS dialect that acts as a semantic abstraction layer for hardware-specific matrix multiply operations inside the IREE compiler. Key components: - IR/MIPSBase.td / MIPSDialect.h/.cpp: dialect definition, namespace ::mlir::iree_compiler::IREE::MIPS, dependent on func/memref/tensor. - IR/MIPSOps.td / MIPSOps.h/.cpp: mips.matmul op, tensor-only, Destination-Passing-Style (DPS). Verifier checks 2-D tensor shapes (lhs[MxK], rhs[KxN], init[MxN]) and element-type consistency. Implements ReifyRankedShapedTypeOpInterface and MemoryEffectsOpInterface. - IR/MIPSBufferizableOpInterface.cpp: eliminates mips.matmul entirely during One-Shot Bufferize by emitting func.call @my_matmul_kernel directly with the memref (base_ptr, offset, stride0, stride1) ABI. Uses memref.memory_space_cast to strip IREE HAL memory spaces from the base pointers so the function declaration stays stable across eraseHALDescriptorTypeFromMemRef. - Transforms/Passes.td/.h/.cpp: pass registry for LowerMIPSToFuncCallPass (now a no-op; kept for pipeline compatibility since bufferization handles the lowering directly). - Transforms/LowerMIPSToFuncCall.cpp: no-op pass stub. - DispatchCreation/FormDispatchRegions.cpp: teach IREE's dispatch formation to treat mips.matmul as a compute-heavy op eligible for outlining into flow.dispatch.workgroups. - Tools/init_iree_dialects.h: register MIPSDialect and its BufferizableOpInterface external models. - Tools/init_iree_passes.h: register MIPS passes with the global pass registry. 
--- .../iree/compiler/Dialect/MIPS/CMakeLists.txt | 8 + .../compiler/Dialect/MIPS/IR/CMakeLists.txt | 70 ++++++ .../iree/compiler/Dialect/MIPS/IR/MIPSBase.td | 47 ++++ .../MIPS/IR/MIPSBufferizableOpInterface.cpp | 216 ++++++++++++++++++ .../compiler/Dialect/MIPS/IR/MIPSDialect.cpp | 56 +++++ .../compiler/Dialect/MIPS/IR/MIPSDialect.h | 28 +++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp | 98 ++++++++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.h | 25 ++ .../iree/compiler/Dialect/MIPS/IR/MIPSOps.td | 71 ++++++ .../Dialect/MIPS/Transforms/CMakeLists.txt | 56 +++++ .../MIPS/Transforms/LowerMIPSToFuncCall.cpp | 36 +++ .../Dialect/MIPS/Transforms/Passes.cpp | 18 ++ .../compiler/Dialect/MIPS/Transforms/Passes.h | 33 +++ .../Dialect/MIPS/Transforms/Passes.td | 49 ++++ .../MIPS/Transforms/test/CMakeLists.txt | 15 ++ .../Transforms/test/lower_to_func_call.mlir | 49 ++++ .../compiler/DispatchCreation/CMakeLists.txt | 1 + .../DispatchCreation/FormDispatchRegions.cpp | 4 + .../iree/compiler/Tools/init_iree_dialects.h | 3 + .../iree/compiler/Tools/init_iree_passes.h | 2 + 20 files changed, 885 insertions(+) create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp create mode 
100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt create mode 100644 compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir diff --git a/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt new file mode 100644 index 000000000000..3da5a7a85912 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Recursively picks up IR/ and Transforms/ subdirectories. +iree_add_all_subdirs() diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt new file mode 100644 index 000000000000..632b1ee5b190 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/CMakeLists.txt @@ -0,0 +1,70 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +# ─── Tablegen ──────────────────────────────────────────────────────────────── +# Generate op declarations/definitions AND dialect declarations/definitions +# from a single MIPSOps.td source (same pattern as LinalgExt). 
+ +iree_tablegen_library( + NAME + MIPSOpsIncGen + TD_FILE + "MIPSOps.td" + OUTS + --gen-op-decls MIPSOps.h.inc + --gen-op-defs MIPSOps.cpp.inc + --dialect=mips --gen-dialect-decls MIPSDialect.h.inc + --dialect=mips --gen-dialect-defs MIPSDialect.cpp.inc +) + +# ─── C++ library ───────────────────────────────────────────────────────────── + +iree_cc_library( + NAME + IR + HDRS + "MIPSDialect.h" + "MIPSOps.h" + "MIPSDialect.h.inc" + TEXTUAL_HDRS + "MIPSOps.h.inc" + "MIPSOps.cpp.inc" + SRCS + "MIPSDialect.cpp" + "MIPSDialect.cpp.inc" + "MIPSOps.cpp" + "MIPSBufferizableOpInterface.cpp" + DEPS + ::MIPSOpsIncGen + LLVMSupport + MLIRIR + MLIRSupport + MLIRFuncDialect + MLIRMemRefDialect + MLIRTensorDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + MLIRTensorUtils + MLIRTransforms + PUBLIC +) + +# ─── Documentation ─────────────────────────────────────────────────────────── + +iree_tablegen_doc( + NAME + MIPSDialectDocGen + CATEGORY "Dialects" + TD_FILE + "MIPSOps.td" + OUTS + --gen-dialect-doc -dialect=mips MIPSDialect.md +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td new file mode 100644 index 000000000000..fd006c3bfcd7 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBase.td @@ -0,0 +1,47 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MIPS_BASE
+#define IREE_DIALECT_MIPS_BASE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// MIPS dialect definition
+//===----------------------------------------------------------------------===//
+
+def MIPS_Dialect : Dialect {
+  let name = "mips";
+  let cppNamespace = "::mlir::iree_compiler::IREE::MIPS";
+  let summary = "MIPS custom compute dialect for experimental matmul pipeline.";
+  let description = [{
+    The MIPS dialect defines a single `mips.matmul` operation that serves as
+    a custom intermediate representation between the Torch frontend and the
+    final `func.call @my_matmul_kernel` generated by the MIPS lowering pass.
+
+    Pipeline:
+      torch.aten.mm → mips.matmul → func.call @my_matmul_kernel
+  }];
+
+  let dependentDialects = [
+    "::mlir::func::FuncDialect",
+    "::mlir::memref::MemRefDialect",
+    "::mlir::tensor::TensorDialect"
+  ];
+
+  // No custom attribute types → do not declare parseAttribute/printAttribute
+  // overrides. The base Dialect class handles the fallback behavior.
+  let useDefaultAttributePrinterParser = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Base op class
+//===----------------------------------------------------------------------===//
+
+class MIPS_Op<string mnemonic, list<Trait> traits = []>
+    : Op<MIPS_Dialect, mnemonic, traits>;
+
+#endif // IREE_DIALECT_MIPS_BASE

diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
new file mode 100644
index 000000000000..e5499939bcc3
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSBufferizableOpInterface.cpp
@@ -0,0 +1,216 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Implements BufferizableOpInterface for mips.matmul. +// +// mips.matmul is a tensor-only, Destination-Passing-Style (DPS) op. It is +// eliminated *entirely* during One-Shot Bufferize: bufferize() obtains memref +// buffers for all three operands, decomposes each 2-D memref into +// (base_ptr, offset, stride0, stride1) via memref.extract_strided_metadata, +// and emits a func.call @my_matmul_kernel directly. No memref form of +// mips.matmul ever exists in the IR. +// +// Before bufferization: +// %C = mips.matmul %A, %B, %init +// : tensor, tensor, tensor -> tensor +// +// After bufferization (produced inside bufferize()): +// %A_meta = memref.extract_strided_metadata %A_buf -> (base, off, s0, s1) +// %B_meta = memref.extract_strided_metadata %B_buf -> (base, off, s0, s1) +// %C_meta = memref.extract_strided_metadata %C_buf -> (base, off, s0, s1) +// %M = memref.dim %A_buf, 0 +// %N = memref.dim %B_buf, 1 +// %K = memref.dim %A_buf, 1 +// call @my_matmul_kernel(%A_base, %A_off, %A_s0, %A_s1, +// %B_base, %B_off, %B_s0, %B_s1, +// %C_base, %C_off, %C_s0, %C_s1, +// %M, %N, %K) +// -- tensor result replaced by %C_buf via replaceOpWithBufferizedValues -- + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir::iree_compiler::IREE::MIPS { +namespace { + +static constexpr StringLiteral kKernelName = "my_matmul_kernel"; + +//===----------------------------------------------------------------------===// +// Helper: 
ensure func.func private @my_matmul_kernel exists at module scope. +// +// The declaration carries {llvm.bareptr = true} so the LLVM backend passes +// bare float* arguments instead of MLIR memref descriptor structs, matching +// the C kernel ABI. +//===----------------------------------------------------------------------===// + +static func::FuncOp ensureKernelDeclaration(RewriterBase &rewriter, + Operation *moduleOp, + FunctionType fnType, + Location loc) { + if (auto existing = dyn_cast_if_present( + SymbolTable::lookupSymbolIn(moduleOp, kKernelName))) + return existing; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front()); + auto fnDecl = func::FuncOp::create(rewriter, loc, kKernelName, fnType); + SymbolTable::setSymbolVisibility(fnDecl, SymbolTable::Visibility::Private); + fnDecl->setAttr("llvm.bareptr", rewriter.getBoolAttr(true)); + return fnDecl; +} + +//===----------------------------------------------------------------------===// +// Helper: decompose a 2-D memref into (base_ptr, offset, stride0, stride1). +// +// Uses memref.extract_strided_metadata. The base_ptr is always a rank-0 +// memref with DEFAULT address space (memref), regardless of the source +// memref's address space. Any IREE-specific memory space (e.g. +// #hal.descriptor_type) is stripped via +// memref.memory_space_cast so that: +// +// 1. The function declaration uses plain memref, which is stable across +// all pipeline stages. +// 2. eraseHALDescriptorTypeFromMemRefPass (which runs after bufferization and +// does NOT update external function declarations) cannot introduce a +// type mismatch between the call operands and the declaration. +// +// Combined with the {llvm.bareptr = true} attribute on the callee, the +// rank-0 memref lowers to a bare float* matching the C ABI. 
+//===----------------------------------------------------------------------===// + +static void decomposeMemref2D(RewriterBase &rewriter, Location loc, + Value memref2D, + SmallVectorImpl &callOperands, + SmallVectorImpl &callArgTypes) { + Type indexType = IndexType::get(rewriter.getContext()); + + auto meta = + memref::ExtractStridedMetadataOp::create(rewriter, loc, memref2D); + + // Strip any IREE-specific memory space from the base pointer so the + // function declaration stays in the default address space. + Value basePtr = meta.getBaseBuffer(); + auto basePtrMemrefTy = cast(basePtr.getType()); + MemRefType plainBasePtrTy = + MemRefType::get(/*shape=*/{}, basePtrMemrefTy.getElementType()); + if (basePtrMemrefTy != plainBasePtrTy) { + basePtr = memref::MemorySpaceCastOp::create(rewriter, loc, plainBasePtrTy, + basePtr); + } + + callOperands.push_back(basePtr); + callArgTypes.push_back(plainBasePtrTy); + + callOperands.push_back(meta.getOffset()); + callArgTypes.push_back(indexType); + + for (Value stride : meta.getStrides()) { + callOperands.push_back(stride); + callArgTypes.push_back(indexType); + } +} + +//===----------------------------------------------------------------------===// +// External model — BufferizableOpInterface for mips.matmul. +// +// Inherits from DstBufferizableOpInterfaceExternalModel which automatically +// handles the DPS aliasing (init ↔ result) and write detection for the init +// operand. We override bufferizesToMemoryRead to mark lhs and rhs as read, +// and provide a custom bufferize() that emits func.call @my_matmul_kernel. +//===----------------------------------------------------------------------===// + +struct MIPSMatmulBufferizableOpInterface + : public DstBufferizableOpInterfaceExternalModel< + MIPSMatmulBufferizableOpInterface, MatmulOp> { + + // All three operands are read by the kernel. 
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options, + BufferizationState &state) const { + auto matmulOp = cast(op); + Location loc = matmulOp.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(matmulOp); + + // Obtain memref buffers for all three tensor operands. + FailureOr lhsBuf = + getBuffer(rewriter, matmulOp.getLhs(), options, state); + if (failed(lhsBuf)) + return failure(); + FailureOr rhsBuf = + getBuffer(rewriter, matmulOp.getRhs(), options, state); + if (failed(rhsBuf)) + return failure(); + // init aliases with result — one-shot bufferize allocates the output buffer + // (via bufferization.alloc_tensor or in-place analysis) and gives it to us + // here as initBuf. + FailureOr initBuf = + getBuffer(rewriter, matmulOp.getInit(), options, state); + if (failed(initBuf)) + return failure(); + + // Build the flattened argument list for func.call @my_matmul_kernel. + // For each 2-D memref: (base_ptr, offset, stride0, stride1) + // Then: M, N, K as index scalars. + SmallVector callOperands; + SmallVector callArgTypes; + + decomposeMemref2D(rewriter, loc, *lhsBuf, callOperands, callArgTypes); + decomposeMemref2D(rewriter, loc, *rhsBuf, callOperands, callArgTypes); + decomposeMemref2D(rewriter, loc, *initBuf, callOperands, callArgTypes); + + Type indexType = IndexType::get(ctx); + Value M = memref::DimOp::create(rewriter, loc, *lhsBuf, 0); + Value N = memref::DimOp::create(rewriter, loc, *rhsBuf, 1); + Value K = memref::DimOp::create(rewriter, loc, *lhsBuf, 1); + callOperands.append({M, N, K}); + callArgTypes.append(3, indexType); + + // Declare the kernel function in the enclosing module (idempotent). 
+ Operation *moduleOp = SymbolTable::getNearestSymbolTable(matmulOp); + FunctionType fnType = rewriter.getFunctionType(callArgTypes, TypeRange{}); + ensureKernelDeclaration(rewriter, moduleOp, fnType, loc); + + // Emit the call — the kernel writes into *initBuf in place. + func::CallOp::create(rewriter, loc, kKernelName, TypeRange{}, callOperands); + + // Replace the tensor result with the init buffer (DPS aliasing). + replaceOpWithBufferizedValues(rewriter, op, *initBuf); + return success(); + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Public registration entry point +//===----------------------------------------------------------------------===// + +void registerMIPSBufferizableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, MIPSDialect * /*dialect*/) { + MatmulOp::attachInterface(*ctx); + }); +} + +} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp new file mode 100644 index 000000000000..f136f11594fa --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp @@ -0,0 +1,56 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" + +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::MIPS; + +//===----------------------------------------------------------------------===// +// Inliner interface — allow MIPS ops to be inlined unconditionally. +//===----------------------------------------------------------------------===// + +namespace { +struct MIPSInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *call, Operation *callable, + bool wouldBeCloned) const final { + return true; + } + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } + bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, + IRMapping &valueMapping) const final { + return true; + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Dialect initialize +//===----------------------------------------------------------------------===// + +void MIPSDialect::initialize() { + addInterfaces(); + +#define GET_OP_LIST + addOperations< +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc" + >(); +} + +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h new file mode 100644 index 000000000000..aa99d0a63005 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSDialect.h @@ -0,0 +1,28 @@ +// Copyright 2024 The IREE 
Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ +#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +// clang-format off +// MIPSDialect.h.inc is generated from MIPSOps.td via: +// --dialect=mips --gen-dialect-decls MIPSDialect.h.inc +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h.inc" // IWYU pragma: keep +// clang-format on + +namespace mlir::iree_compiler::IREE::MIPS { + +// Register external BufferizableOpInterface models for MIPS ops. +// Call this from registerIreeDialects() before bufferization runs. +void registerMIPSBufferizableOpInterfaceExternalModels( + mlir::DialectRegistry ®istry); + +} // namespace mlir::iree_compiler::IREE::MIPS + +#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSDIALECT_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp new file mode 100644 index 000000000000..c60c0b1d6c6f --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp @@ -0,0 +1,98 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +using namespace mlir; +using namespace mlir::iree_compiler::IREE::MIPS; + +//===----------------------------------------------------------------------===// +// MatmulOp — ReifyRankedShapedTypeOpInterface +// +// Returns the output shape [M, N] so IREE's dispatch formation can compute +// the workload when wrapping mips.matmul in a flow.dispatch.workgroups region. +//===----------------------------------------------------------------------===// + +LogicalResult MatmulOp::reifyResultShapes( + OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + // Result is always tensor. M from lhs dim 0, N from rhs dim 1. + reifiedReturnShapes.push_back({tensor::getMixedSize(b, getLoc(), getLhs(), 0), + tensor::getMixedSize(b, getLoc(), getRhs(), 1)}); + return success(); +} + +//===----------------------------------------------------------------------===// +// MatmulOp — MemoryEffectsOpInterface +// +// In the tensor domain, ops are nominally pure (tensors are values, not memory). +// However mips.matmul uses DPS — the init operand logically "carries" the +// result. We declare read on lhs/rhs and read+write on init so that alias +// analyses outside of bufferization correctly treat init as modified. 
+//===----------------------------------------------------------------------===// + +void MatmulOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getLhsMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getRhsMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), &getInitMutable(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &getInitMutable(), + SideEffects::DefaultResource::get()); +} + +//===----------------------------------------------------------------------===// +// MatmulOp — Verifier +//===----------------------------------------------------------------------===// + +LogicalResult MatmulOp::verify() { + auto shape = [](Value v) { + return cast(v.getType()).getShape(); + }; + auto elemTy = [](Value v) { + return cast(v.getType()).getElementType(); + }; + + // All operands must be 2-D tensors. + for (Value v : {getLhs(), getRhs(), getInit()}) { + if (cast(v.getType()).getRank() != 2) + return emitOpError("all operands must be 2-D ranked tensors"); + } + + // Dimension compatibility: lhs[M x K], rhs[K x N], init[M x N]. + // Only validate static dimensions; dynamic dims are checked at runtime. + auto compat = [](int64_t a, int64_t b) { + return ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b; + }; + if (!compat(shape(getLhs())[0], shape(getInit())[0])) + return emitOpError("lhs dim 0 (M) must match init dim 0 (M)"); + if (!compat(shape(getLhs())[1], shape(getRhs())[0])) + return emitOpError("lhs dim 1 (K) must match rhs dim 0 (K)"); + if (!compat(shape(getRhs())[1], shape(getInit())[1])) + return emitOpError("rhs dim 1 (N) must match init dim 1 (N)"); + + // All element types must match. 
+ if (elemTy(getLhs()) != elemTy(getRhs()) || elemTy(getLhs()) != elemTy(getInit())) + return emitOpError("element types of all operands must match"); + + // Result type must match init type (both tensor). + if (getResult().getType() != getInit().getType()) + return emitOpError("result type must match init type"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen generated op definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.cpp.inc" diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h new file mode 100644 index 000000000000..dc2881aca2e3 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.h @@ -0,0 +1,25 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_
+#define IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_
+
+#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+// clang-format off
+
+#define GET_OP_CLASSES
+#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h.inc" // IWYU pragma: export
+
+// clang-format on
+
+#endif // IREE_COMPILER_DIALECT_MIPS_IR_MIPSOPS_H_

diff --git a/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td
new file mode 100644
index 000000000000..e87073fe3b99
--- /dev/null
+++ b/compiler/src/iree/compiler/Dialect/MIPS/IR/MIPSOps.td
@@ -0,0 +1,71 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_MIPS_OPS
+#define IREE_DIALECT_MIPS_OPS
+
+include "iree/compiler/Dialect/MIPS/IR/MIPSBase.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// mips.matmul — tensor-only semantic op
+//
+// This op exists exclusively in the tensor domain. It is eliminated during
+// One-Shot Bufferize: the BufferizableOpInterface implementation obtains
+// the output memref and emits func.call @my_matmul_kernel(...) directly, so
+// no memref form of this op ever exists in the IR.
+//===----------------------------------------------------------------------===//
+
+def MIPS_MatmulOp : MIPS_Op<"matmul", [
+  DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  DestinationStyleOpInterface
+]> {
+  let summary = "MIPS matrix multiplication (tensor domain): result = lhs * rhs";
+  let description = [{
+    Computes a 2-D matrix multiplication in the tensor domain using
+    destination-passing style (DPS). The caller provides an `init` tensor that
+    One-Shot Bufferize uses to determine the output buffer — typically
+    `bufferization.alloc_tensor` for a fresh allocation.
+
+    The semantic is: `result[m, n] = sum_k(lhs[m, k] * rhs[k, n])`.
+
+    This op is created by the Torch -> MIPS conversion pass from `torch.aten.mm`
+    and is eliminated entirely during bufferization: the BufferizableOpInterface
+    implementation emits `func.call @my_matmul_kernel` directly with the
+    bufferized memref operands. No memref-form `mips.matmul` is ever produced.
+
+    Example:
+    ```mlir
+    %result = mips.matmul %A, %B, %init
+        : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$lhs,   // [M x K]
+    AnyRankedTensor:$rhs,   // [K x N]
+    AnyRankedTensor:$init   // [M x N] — DPS destination (typically alloc_tensor)
+  );
+
+  let results = (outs AnyRankedTensor:$result);  // [M x N]
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs `,` $init attr-dict `:`
+    type($lhs) `,` type($rhs) `,` type($init) `->` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    // DestinationStyleOpInterface: init is the DPS output operand.
+ MutableOperandRange getDpsInitsMutable() { return getInitMutable(); } + }]; + + let hasVerifier = 1; +} + +#endif // IREE_DIALECT_MIPS_OPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..b37cc85e3b29 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_add_all_subdirs() + +# ─── Tablegen: generate Passes.h.inc from Passes.td ────────────────────────── + +iree_tablegen_library( + NAME + PassesIncGen + TD_FILE + "Passes.td" + OUTS + --gen-pass-decls Passes.h.inc +) + +# ─── Header-only pass declarations ─────────────────────────────────────────── + +iree_cc_library( + NAME + PassHeaders + HDRS + "Passes.h" + "Passes.h.inc" + DEPS + ::PassesIncGen + MLIRPass + MLIRTransforms + PUBLIC +) + +# ─── Full transforms library ───────────────────────────────────────────────── + +iree_cc_library( + NAME + Transforms + HDRS + "Passes.h" + SRCS + "LowerMIPSToFuncCall.cpp" + "Passes.cpp" + DEPS + ::PassHeaders + ::PassesIncGen + iree::compiler::Dialect::MIPS::IR + MLIRFuncDialect + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransformUtils + MLIRTransforms + PUBLIC +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp new file mode 100644 index 000000000000..df9caa9608f1 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/LowerMIPSToFuncCall.cpp @@ -0,0 +1,36 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// mips.matmul is a tensor-only op that is eliminated entirely during +// One-Shot Bufferize: the BufferizableOpInterface implementation in +// MIPSBufferizableOpInterface.cpp emits func.call @my_matmul_kernel directly. +// +// This pass is therefore a no-op and exists only for registration purposes +// (so that --iree-mips-lower-to-func-call can be specified on the command line +// without error, and so that any pipeline that references it still compiles). + +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep +#include "mlir/Dialect/MemRef/IR/MemRef.h" // IWYU pragma: keep +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::MIPS { + +#define GEN_PASS_DEF_LOWERMIPSTOFUNCCALLPASS +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" + +namespace { + +struct LowerMIPSToFuncCallPass + : impl::LowerMIPSToFuncCallPassBase { + void runOnOperation() override { + // mips.matmul is eliminated during One-Shot Bufferize (see + // MIPSBufferizableOpInterface.cpp). No work to do here. + } +}; + +} // namespace +} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp new file mode 100644 index 000000000000..751ea73789d8 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.cpp @@ -0,0 +1,18 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" + +namespace mlir::iree_compiler::IREE::MIPS { + +namespace { +#define GEN_PASS_REGISTRATION +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" +} // namespace + +void registerMIPSPasses() { registerPasses(); } + +} // namespace mlir::iree_compiler::IREE::MIPS diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h new file mode 100644 index 000000000000..8b4390167407 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.h @@ -0,0 +1,33 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ +#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::IREE::MIPS { + +//===----------------------------------------------------------------------===// +// Pass factory functions (generated by tablegen + implemented in Passes.cpp) +//===----------------------------------------------------------------------===// + +/// Creates the pass that lowers memref-form mips.matmul to +/// func.call @my_matmul_kernel. +std::unique_ptr<::mlir::Pass> createLowerMIPSToFuncCallPass(); + +/// Registers all MIPS passes with the global pass registry so they can be +/// invoked from the command line (e.g. `iree-opt --iree-mips-lower-to-func-call`).
+void registerMIPSPasses(); + +} // namespace mlir::iree_compiler::IREE::MIPS + +// clang-format off +#define GEN_PASS_DECL +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h.inc" // IWYU pragma: keep +// clang-format on + +#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES_H_ diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td new file mode 100644 index 000000000000..b081e3a58a7c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/Passes.td @@ -0,0 +1,49 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES +#define IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +//===----------------------------------------------------------------------===// +// LowerMIPSToFuncCallPass +// +// Converts memref-form mips.matmul operations (produced after one-shot +// bufferization in the LLVMCPU codegen pipeline) to: +// +// func.func private @my_matmul_kernel(...) attributes {llvm.bareptr = true} +// func.call @my_matmul_kernel(A_base, A_off, A_s0, A_s1, +// B_base, B_off, B_s0, B_s1, +// C_base, C_off, C_s0, C_s1, +// M, N, K) +// +// The pass uses memref.extract_strided_metadata to decompose each memref into +// a (base_pointer, offset, strides...) tuple matching the C ABI of the kernel. +//===----------------------------------------------------------------------===// + +def LowerMIPSToFuncCallPass : + Pass<"iree-mips-lower-to-func-call", "ModuleOp"> { + let summary = "Lower mips.matmul (memref form) to func.call @my_matmul_kernel"; + let description = [{ + Walks all mips.matmul ops in the module and replaces each one with a call + to the external C kernel `my_matmul_kernel`. 
The memref operands are + decomposed via `memref.extract_strided_metadata` into base-pointer + offset + + stride arguments, matching the ABI declared in my_matmul_kernel.h. + + The pass creates a `func.func private @my_matmul_kernel` declaration with + `{llvm.bareptr = true}` so that the LLVM backend passes bare float* pointers + rather than MLIR memref descriptor structs. + + This pass runs after one-shot bufferization in the LLVMCPU codegen pipeline. + }]; + let dependentDialects = [ + "::mlir::func::FuncDialect", + "::mlir::memref::MemRefDialect" + ]; +} + +#endif // IREE_COMPILER_DIALECT_MIPS_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt new file mode 100644 index 000000000000..725370463654 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_lit_test_suite( + NAME + lit + SRCS + "lower_to_func_call.mlir" + TOOLS + FileCheck + iree-opt +) diff --git a/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir new file mode 100644 index 000000000000..582419acf589 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/MIPS/Transforms/test/lower_to_func_call.mlir @@ -0,0 +1,49 @@ +// RUN: iree-opt --split-input-file --iree-mips-lower-to-func-call %s \ +// RUN: | FileCheck %s + +// ───────────────────────────────────────────────────────────────────────────── +// Basic static-shape: memref-form mips.matmul → func.call @my_matmul_kernel +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK: func.func private @my_matmul_kernel +// CHECK-SAME: {llvm.bareptr = true} +// +// CHECK-LABEL: func.func @lower_mips_matmul +// CHECK-NOT: mips.matmul +// CHECK: memref.extract_strided_metadata +// CHECK: call @my_matmul_kernel +module { + func.func @lower_mips_matmul(%A: memref<4x8xf32>, + %B: memref<8x4xf32>, + %C: memref<4x4xf32>) { + mips.matmul %A, %B, %C + : memref<4x8xf32>, memref<8x4xf32>, memref<4x4xf32> + return + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Multiple matmuls reuse the same @my_matmul_kernel declaration. +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK: func.func private @my_matmul_kernel +// Check that there is exactly one declaration (not two). 
+// CHECK-NOT: func.func private @my_matmul_kernel +// +// CHECK-LABEL: func.func @two_matmuls +// CHECK: call @my_matmul_kernel +// CHECK: call @my_matmul_kernel +module { + func.func @two_matmuls(%A: memref<2x4xf32>, + %B: memref<4x2xf32>, + %C: memref<2x2xf32>, + %D: memref<2x4xf32>, + %E: memref<4x2xf32>, + %F: memref<2x2xf32>) { + mips.matmul %A, %B, %C + : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> + mips.matmul %D, %E, %F + : memref<2x4xf32>, memref<4x2xf32>, memref<2x2xf32> + return + } +} diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt index 4b124334e1ef..01cc625df2af 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt @@ -89,6 +89,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::MIPS::IR iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::TensorExt::Transforms diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 6ba8a4bd8db2..f6a5467aee5c 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/Flow/IR/FlowOps.h" #include "iree/compiler/Dialect/Flow/Transforms/ConvertRegionToWorkgroups.h" @@ -369,6 +370,9 @@ static bool isRootLikeOp(Operation *op) { return !isa(op); } + // MIPS: mips.matmul is a dispatch root (lowered to a custom C 
kernel call). + if (isa<IREE::MIPS::MatmulOp>(op)) + return true; return isa(op); } diff --git a/compiler/src/iree/compiler/Tools/init_iree_dialects.h b/compiler/src/iree/compiler/Tools/init_iree_dialects.h index c47ae6cb4368..78f1d3768f68 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_iree_dialects.h @@ -23,6 +23,7 @@ #include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSDialect.h" #include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" #include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h" #include "iree/compiler/Dialect/Util/IR/UtilDialect.h" @@ -50,6 +51,7 @@ inline void registerIreeDialects(DialectRegistry &registry) { IREE::HAL::Loader::HALLoaderDialect, IREE::IO::Parameters::IOParametersDialect, IREE::LinalgExt::IREELinalgExtDialect, + IREE::MIPS::MIPSDialect, IREE::PCF::PCFDialect, IREE::Encoding::IREEEncodingDialect, IREE::Stream::StreamDialect, @@ -65,6 +67,7 @@ inline void registerIreeDialects(DialectRegistry &registry) { registerCodegenInterfaces(registry); registerGlobalOptimizationInterfaces(registry); registerUKernelBufferizationInterface(registry); + IREE::MIPS::registerMIPSBufferizableOpInterfaceExternalModels(registry); // Register transform dialect extensions.
registerTransformDialectPreprocessingExtension(registry); diff --git a/compiler/src/iree/compiler/Tools/init_iree_passes.h b/compiler/src/iree/compiler/Tools/init_iree_passes.h index 6f7de0752f45..d113b8d61d3d 100644 --- a/compiler/src/iree/compiler/Tools/init_iree_passes.h +++ b/compiler/src/iree/compiler/Tools/init_iree_passes.h @@ -20,6 +20,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include "iree/compiler/Dialect/Stream/Transforms/Passes.h" #include "iree/compiler/Dialect/TensorExt/Transforms/Passes.h" #include "iree/compiler/Dialect/Util/Transforms/Passes.h" @@ -62,6 +63,7 @@ inline void registerAllIreePasses() { IREE::HAL::Loader::registerHALLoaderPasses(); IREE::IO::Parameters::registerParametersPasses(); IREE::LinalgExt::registerPasses(); + IREE::MIPS::registerMIPSPasses(); IREE::Stream::registerStreamPasses(); IREE::TensorExt::registerPasses(); IREE::Util::registerUtilPasses(); From f3c7ba519ff06a101f8378aba30a3a39feb99aac Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:10:58 +0530 Subject: [PATCH 2/5] [MIPS] Add Torch-to-MIPS conversion pass and LLVMCPU pipeline wiring Adds the frontend conversion from torch.aten.mm to mips.matmul and wires the MIPS dialect into the IREE LLVMCPU codegen pipeline. Torch InputConversion changes: - ConvertTorchToMIPS.cpp: ConvertAtenMmToMIPSMatmul pattern rewrites torch.aten.mm to mips.matmul with a zero-initialized init tensor (bufferization.alloc_tensor). The pass runs before the standard ConvertTorchToLinalgPass so mips.matmul takes precedence over the generic linalg.matmul path. - Passes.td / Passes.h / Passes.cpp: declare ConvertTorchToMIPSPass and add it to the torch-to-iree pipeline under the use-mips-matmul option. 
- test/convert_torch_to_mips.mlir: FileCheck test verifying that torch.aten.mm is replaced by mips.matmul after the pass. LLVMCPU codegen pipeline changes: - Passes.cpp: insert LowerMIPSToFuncCallPass (no-op stub) in the post-bufferize section of buildLLVMCPUCodegenPassPipeline. The actual lowering to func.call is performed during One-Shot Bufferize by MIPSBufferizableOpInterface; this stub ensures the pass slot is reserved for future use and keeps the pipeline definition explicit. - CMakeLists.txt: add iree_compiler_Dialect_MIPS_Transforms_Transforms dependency to the LLVMCPU codegen target. --- .../Torch/InputConversion/CMakeLists.txt | 2 + .../InputConversion/ConvertTorchToMIPS.cpp | 155 ++++++++++++++++++ .../input/Torch/InputConversion/Passes.cpp | 5 + .../input/Torch/InputConversion/Passes.h | 6 + .../input/Torch/InputConversion/Passes.td | 10 ++ .../Torch/InputConversion/test/CMakeLists.txt | 1 + .../test/convert_torch_to_mips.mlir | 37 +++++ .../compiler/Codegen/LLVMCPU/CMakeLists.txt | 1 + .../iree/compiler/Codegen/LLVMCPU/Passes.cpp | 5 + 9 files changed, 222 insertions(+) create mode 100644 compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp create mode 100644 compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir diff --git a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt index 2738c8d9fa99..d9d0fa598477 100644 --- a/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/CMakeLists.txt @@ -37,6 +37,7 @@ iree_cc_library( "BindSymbolicShapes.cpp" "BitCastTensor.cpp" "ConvertTMTensorToLinalgExt.cpp" + "ConvertTorchToMIPS.cpp" "ConvertTorchUnstructuredToLinalgExt.cpp" "FuncConversion.cpp" "SetStrictSymbolicShapes.cpp" @@ -57,6 +58,7 @@ iree_cc_library( iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::IR iree::compiler::Dialect::LinalgExt::IR + iree::compiler::Dialect::MIPS::IR 
iree::compiler::Dialect::Stream::IR iree::compiler::Dialect::TensorExt::IR PUBLIC diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp new file mode 100644 index 000000000000..1999f8695dfd --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp @@ -0,0 +1,155 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Converts torch.aten.mm → mips.matmul. +// +// The pattern runs inside the Torch input-conversion pipeline, BEFORE +// createConvertTorchToLinalgPass(), so it intercepts aten.mm first. +// +// Since torch ops carry ValueTensorType (torch's tensor type), the pattern: +// 1. Casts operands to builtin RankedTensorType via ToBuiltinTensorOp. +// 2. Creates a zero-initialised init tensor (Destination Passing Style). +// 3. Emits mips.matmul on builtin tensors. +// 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp. +// +// This mirrors the approach in ConvertTorchUnstructuredToLinalgExt.cpp. 
+ +#include "compiler/plugins/input/Torch/InputConversion/Passes.h" +#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" + +namespace mlir::iree_compiler::TorchInput { + +#define GEN_PASS_DEF_CONVERTTORCHTOMIPSPASS +#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Helper: create a zero-filled tensor of a given shape and element type. +// Accepts (M, N) as dynamic Value dimensions. +//===----------------------------------------------------------------------===// + +static Value createZeroTensor(PatternRewriter &rewriter, Location loc, + RankedTensorType ty, ValueRange dynSizes) { + Value empty = tensor::EmptyOp::create(rewriter, loc, ty, dynSizes); + Attribute zeroAttr = rewriter.getZeroAttr(ty.getElementType()); + Value zero = arith::ConstantOp::create(rewriter, loc, cast(zeroAttr)); + return linalg::FillOp::create(rewriter, loc, zero, empty).result(); +} + +//===----------------------------------------------------------------------===// +// Pattern: torch.aten.mm → mips.matmul +//===----------------------------------------------------------------------===// + +struct ConvertAtenMmToMIPSMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(torch::Torch::AtenMmOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // ---------------------------------------------------------------- + // 1. Verify that we have supported tensor types. 
+ // ---------------------------------------------------------------- + auto lhsTorchTy = + dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType()); + auto rhsTorchTy = + dyn_cast<torch::Torch::ValueTensorType>(op.getMat2().getType()); + auto resultTorchTy = + dyn_cast<torch::Torch::ValueTensorType>(op.getType()); + + if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy) + return rewriter.notifyMatchFailure(op, "expected ValueTensorType"); + + // Only handle f32 for now (extensible). + if (!lhsTorchTy.getDtype().isF32()) + return rewriter.notifyMatchFailure(op, "only f32 supported"); + + // ---------------------------------------------------------------- + // 2. Cast operands from torch ValueTensorType → builtin RankedTensorType. + // ---------------------------------------------------------------- + auto lhsBuiltinTy = + dyn_cast_or_null<RankedTensorType>(lhsTorchTy.toBuiltinTensor()); + auto rhsBuiltinTy = + dyn_cast_or_null<RankedTensorType>(rhsTorchTy.toBuiltinTensor()); + auto resultBuiltinTy = + dyn_cast_or_null<RankedTensorType>(resultTorchTy.toBuiltinTensor()); + + if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy || + lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors"); + + Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, lhsBuiltinTy, op.getSelf()); + Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create( + rewriter, loc, rhsBuiltinTy, op.getMat2()); + + // ---------------------------------------------------------------- + // 3. Collect dynamic dimension values for the result tensor (M, N). + // ---------------------------------------------------------------- + SmallVector<Value> dynSizes; + if (resultBuiltinTy.isDynamicDim(0)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0)); + if (resultBuiltinTy.isDynamicDim(1)) + dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1)); + + // ---------------------------------------------------------------- + // 4. Create a zero-initialised init tensor for DPS output.
+ // ---------------------------------------------------------------- + Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes); + + // ---------------------------------------------------------------- + // 5. Emit mips.matmul on builtin tensors. + // ---------------------------------------------------------------- + Value result = + IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy}, + lhs, rhs, init) + .getResult(); + + // ---------------------------------------------------------------- + // 6. Cast result back to ValueTensorType so downstream torch passes can + // still operate on it until the type finalisation pass runs. + // ---------------------------------------------------------------- + Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create( + rewriter, loc, resultTorchTy, result); + + rewriter.replaceOp(op, torchResult); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +struct ConvertTorchToMIPSPass + : impl::ConvertTorchToMIPSPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::iree_compiler::TorchInput diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.cpp b/compiler/plugins/input/Torch/InputConversion/Passes.cpp index ce0cb9c34f36..82fdf118b9d5 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.cpp +++ b/compiler/plugins/input/Torch/InputConversion/Passes.cpp @@ -63,6 +63,11 @@ void createTorchToIREEPipeline( pm.addNestedPass(torch::createConvertTorchToTensorPass()); pm.addNestedPass( 
TorchInput::createConvertTorchUnstructuredToLinalgExtPass()); + // MIPS: When enabled, intercept aten.mm before the standard torch->linalg + // pass and route it through mips.matmul -> func.call @my_matmul_kernel. + if (options.useMIPSMatmul) { + pm.addNestedPass(TorchInput::createConvertTorchToMIPSPass()); + } pm.addNestedPass(torch::createConvertTorchToLinalgPass()); pm.addNestedPass(createCSEPass()); pm.addNestedPass(torch::createConvertTorchToSCFPass()); diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.h b/compiler/plugins/input/Torch/InputConversion/Passes.h index 23995b21a943..8eb22c3a6034 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.h +++ b/compiler/plugins/input/Torch/InputConversion/Passes.h @@ -43,6 +43,12 @@ struct TorchToIREELoweringPipelineOptions "program inputs. This buffer will be used for storing transient " "memory and must be provided by the user."), llvm::cl::init(false)}; + Option useMIPSMatmul{ + *this, "use-mips-matmul", + llvm::cl::desc("If enabled, lowers torch.aten.mm through the MIPS " + "custom dialect (mips.matmul) instead of the standard " + "torch->linalg path."), + llvm::cl::init(false)}; }; // Creates a pipeline that lowers from the torch backend contract to IREE. diff --git a/compiler/plugins/input/Torch/InputConversion/Passes.td b/compiler/plugins/input/Torch/InputConversion/Passes.td index a868d4bb8354..fb55ba7d189f 100644 --- a/compiler/plugins/input/Torch/InputConversion/Passes.td +++ b/compiler/plugins/input/Torch/InputConversion/Passes.td @@ -29,6 +29,16 @@ def ConvertTorchUnstructuredToLinalgExtPass : let summary = "Convert unstructured Torch ops to LinalgExt ops"; } +def ConvertTorchToMIPSPass : + InterfacePass<"torch-iree-to-mips-matmul", "mlir::FunctionOpInterface"> { + let summary = "Convert torch.aten.mm to mips.matmul"; + let description = [{ + Intercepts torch.aten.mm before the standard torch->linalg conversion and + replaces it with mips.matmul. 
The mips.matmul op is later lowered to a + func.call @my_matmul_kernel by LowerMIPSToFuncCallPass after bufferization. + }]; +} + def SetStrictSymbolicShapesPass : InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> { let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR"; diff --git a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt index b1785a708878..3d451b68f13a 100644 --- a/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt +++ b/compiler/plugins/input/Torch/InputConversion/test/CMakeLists.txt @@ -9,6 +9,7 @@ iree_lit_test_suite( "attention.mlir" "bind_symbolic_shapes.mlir" "bitcast_tensor.mlir" + "convert_torch_to_mips.mlir" "func_conversion.mlir" "func_conversion_invalid.mlir" "func_conversion_transients.mlir" diff --git a/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir new file mode 100644 index 000000000000..98d5407fefc5 --- /dev/null +++ b/compiler/plugins/input/Torch/InputConversion/test/convert_torch_to_mips.mlir @@ -0,0 +1,37 @@ +// RUN: iree-opt --split-input-file \ +// RUN: --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \ +// RUN: %s | FileCheck %s + +// ───────────────────────────────────────────────────────────────────────────── +// Static-shape: torch.aten.mm on f32 tensors → mips.matmul +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_static +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<4x8xf32> +// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<8x4xf32> +// CHECK: mips.matmul {{.*}} : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32> +// CHECK-NOT: torch.aten.mm +func.func @mm_static(%A: !torch.vtensor<[4,8],f32>, + %B: !torch.vtensor<[8,4],f32>) + -> 
!torch.vtensor<[4,4],f32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,4],f32> + -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ───────────────────────────────────────────────────────────────────────────── +// Non-f32 (i32) should be left untouched (pattern rejects non-f32 dtypes). +// ───────────────────────────────────────────────────────────────────────────── + +// CHECK-LABEL: func.func @mm_i32_unchanged +// CHECK-NOT: mips.matmul +// CHECK: torch.aten.mm +func.func @mm_i32_unchanged(%A: !torch.vtensor<[4,8],si32>, + %B: !torch.vtensor<[8,4],si32>) + -> !torch.vtensor<[4,4],si32> { + %0 = torch.aten.mm %A, %B + : !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,4],si32> + -> !torch.vtensor<[4,4],si32> + return %0 : !torch.vtensor<[4,4],si32> +} diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index ebd8d7d93755..173e4cdc06db 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -166,6 +166,7 @@ iree_cc_library( iree::compiler::Dialect::LinalgExt::IR iree::compiler::Dialect::LinalgExt::Transforms iree::compiler::Dialect::LinalgExt::Utils + iree::compiler::Dialect::MIPS::Transforms iree::compiler::Dialect::TensorExt::IR iree::compiler::Dialect::Util::IR iree::compiler::Dialect::Util::Transforms diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp index 771839bd3400..c0d9ed9981e4 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp @@ -14,6 +14,7 @@ #include "iree/compiler/Codegen/LLVMCPU/Passes.h" #include "iree/compiler/Codegen/Utils/CodegenOptions.h" #include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" +#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h" #include 
"iree/compiler/Dialect/Util/Transforms/Passes.h" #include "iree/compiler/Utils/PassUtils.h" #include "llvm/ADT/TypeSwitch.h" @@ -518,6 +519,10 @@ static void addLowerToLLVMPasses(OpPassManager &modulePassManager, FunctionLikeNest(modulePassManager) .addPass(createEraseHALDescriptorTypeFromMemRefPass); + // mips.matmul is eliminated during One-Shot Bufferize (func.call emitted + // directly by MIPSBufferizableOpInterface). This pass is now a no-op. + modulePassManager.addPass(IREE::MIPS::createLowerMIPSToFuncCallPass()); + // Lower `ukernel.*` ops to function calls modulePassManager.addPass(createLowerUKernelOpsToCallsPass()); From 2958952bf1787aabf339c8de5b635c0749ff62a3 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 3 Mar 2026 15:11:49 +0530 Subject: [PATCH 3/5] [MIPS] Add my_matmul_kernel IREE executable plugin library Provides the runtime kernel that backs mips.matmul dispatches. The kernel is packaged as an IREE executable plugin (implements iree_hal_executable_plugin_query) rather than a plain shared library because IREE's LLVMCPU system-dylib dispatch format resolves external function references through an internal import table (not standard ELF dynamic linking). The plugin's resolve() callback maps the symbol name "my_matmul_kernel" to the import-ABI wrapper. Usage: iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... 
--- runtime/src/iree/builtins/mips/CMakeLists.txt | 31 ++++ .../src/iree/builtins/mips/my_matmul_kernel.c | 141 ++++++++++++++++++ .../src/iree/builtins/mips/my_matmul_kernel.h | 40 +++++ 3 files changed, 212 insertions(+) create mode 100644 runtime/src/iree/builtins/mips/CMakeLists.txt create mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.c create mode 100644 runtime/src/iree/builtins/mips/my_matmul_kernel.h diff --git a/runtime/src/iree/builtins/mips/CMakeLists.txt b/runtime/src/iree/builtins/mips/CMakeLists.txt new file mode 100644 index 000000000000..19574a5148f7 --- /dev/null +++ b/runtime/src/iree/builtins/mips/CMakeLists.txt @@ -0,0 +1,31 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Shared library containing the MIPS custom matmul kernel. +# +# The kernel is exposed as an IREE executable plugin (implements +# iree_hal_executable_plugin_query) so that iree-run-module can load it via: +# iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... +# +# The IREE runtime resolves the @my_matmul_kernel import from the dispatch +# executable through this plugin's resolve() function. 
+ +add_library(my_matmul_kernel SHARED + my_matmul_kernel.c +) + +target_include_directories(my_matmul_kernel + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} + PRIVATE + # For iree/hal/local/executable_plugin.h (standalone C header, no deps) + ${PROJECT_SOURCE_DIR}/runtime/src +) + +set_target_properties(my_matmul_kernel PROPERTIES + C_VISIBILITY_PRESET default + POSITION_INDEPENDENT_CODE ON +) diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.c b/runtime/src/iree/builtins/mips/my_matmul_kernel.c new file mode 100644 index 000000000000..e952e72b22e0 --- /dev/null +++ b/runtime/src/iree/builtins/mips/my_matmul_kernel.c @@ -0,0 +1,141 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// Naive triple-loop matmul kernel exposed as an IREE executable plugin. +// +// The IREE-compiled dispatch calls @my_matmul_kernel through the IREE import +// mechanism (not direct dynamic linking). At runtime the plugin is registered +// via: +// iree-run-module --executable_plugin=libmy_matmul_kernel.dylib ... +// +// IREE import calling convention (for all imports): +// int fn(void* params_ptr, void* context, void* reserved) +// +// The params_ptr points to a packed struct whose fields mirror the +// func.call arguments emitted by LowerMIPSToFuncCall, in order: +// float* A, int64_t A_off, A_s0, A_s1, +// float* B, int64_t B_off, B_s0, B_s1, +// float* C, int64_t C_off, C_s0, C_s1, +// int64_t M, int64_t N, int64_t K + +#include +#include + +// IREE standalone plugin header — only requires C99 standard headers. 
+#include "iree/hal/local/executable_plugin.h" + +//===----------------------------------------------------------------------===// +// Kernel implementation +//===----------------------------------------------------------------------===// + +// Packed argument struct mirroring the func.call arguments from +// LowerMIPSToFuncCall.cpp. +typedef struct { + float *A; + int64_t A_off, A_s0, A_s1; + float *B; + int64_t B_off, B_s0, B_s1; + float *C; + int64_t C_off, C_s0, C_s1; + int64_t M, N, K; +} my_matmul_kernel_args_t; + +// Import thunk wrapper — called by the IREE runtime with the packed args. +static int my_matmul_kernel_import(void *params_ptr, void *context, + void *reserved) { + (void)context; + (void)reserved; + const my_matmul_kernel_args_t *a = (const my_matmul_kernel_args_t *)params_ptr; + + float *A = a->A + a->A_off; + float *B = a->B + a->B_off; + float *C = a->C + a->C_off; + int64_t M = a->M, N = a->N, K = a->K; + int64_t A_s0 = a->A_s0, A_s1 = a->A_s1; + int64_t B_s0 = a->B_s0, B_s1 = a->B_s1; + int64_t C_s0 = a->C_s0, C_s1 = a->C_s1; + + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float acc = 0.0f; + for (int64_t k = 0; k < K; ++k) + acc += A[m * A_s0 + k * A_s1] * B[k * B_s0 + n * B_s1]; + C[m * C_s0 + n * C_s1] = acc; + } + } + return 0; +} + +//===----------------------------------------------------------------------===// +// IREE Executable Plugin interface +//===----------------------------------------------------------------------===// + +static iree_hal_executable_plugin_status_t plugin_load( + const iree_hal_executable_plugin_environment_v0_t *environment, + size_t param_count, + const iree_hal_executable_plugin_string_pair_t *params, void **out_self) { + (void)environment; + (void)param_count; + (void)params; + *out_self = NULL; // stateless plugin + return iree_hal_executable_plugin_ok_status(); +} + +static void plugin_unload(void *self) { (void)self; } + +static iree_hal_executable_plugin_status_t 
plugin_resolve(
+    void *self, const iree_hal_executable_plugin_resolve_params_v0_t *params,
+    iree_hal_executable_plugin_resolution_t *out_resolution) {
+  (void)self;
+  *out_resolution = 0;
+  bool any_required_not_found = false;
+
+  for (size_t i = 0; i < params->count; ++i) {
+    if (params->out_fn_ptrs[i]) continue;  // already resolved
+    const char *name = params->symbol_names[i];
+    bool optional = iree_hal_executable_plugin_import_is_optional(name);
+    if (optional) ++name;  // skip the leading '?'
+
+    if (iree_hal_executable_plugin_strcmp(name, "my_matmul_kernel") == 0) {
+      params->out_fn_ptrs[i] = my_matmul_kernel_import;
+      params->out_fn_contexts[i] = NULL;
+    } else {
+      if (optional) {
+        *out_resolution |=
+            IREE_HAL_EXECUTABLE_PLUGIN_RESOLUTION_MISSING_OPTIONAL;
+      } else {
+        any_required_not_found = true;
+      }
+    }
+  }
+
+  return any_required_not_found
+             ? iree_hal_executable_plugin_status_from_code(
+                   IREE_HAL_EXECUTABLE_PLUGIN_STATUS_NOT_FOUND)
+             : iree_hal_executable_plugin_ok_status();
+}
+
+// Exported entry point queried by the IREE runtime (via dlsym).
+IREE_HAL_EXECUTABLE_PLUGIN_EXPORT const iree_hal_executable_plugin_header_t **
+iree_hal_executable_plugin_query(
+    iree_hal_executable_plugin_version_t max_version, void *reserved) {
+  static const iree_hal_executable_plugin_header_t header = {
+      .version = IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST,
+      .name = "mips_matmul",
+      .description = "MIPS custom matmul kernel plugin",
+      .features = IREE_HAL_EXECUTABLE_PLUGIN_FEATURE_STANDALONE,
+      .sanitizer = IREE_HAL_EXECUTABLE_PLUGIN_SANITIZER_KIND,
+  };
+  static const iree_hal_executable_plugin_v0_t plugin = {
+      .header = &header,
+      .load = plugin_load,
+      .unload = plugin_unload,
+      .resolve = plugin_resolve,
+  };
+  return max_version <= IREE_HAL_EXECUTABLE_PLUGIN_VERSION_LATEST
+             ?
(const iree_hal_executable_plugin_header_t **)&plugin
+             : NULL;
+}
diff --git a/runtime/src/iree/builtins/mips/my_matmul_kernel.h b/runtime/src/iree/builtins/mips/my_matmul_kernel.h
new file mode 100644
index 000000000000..8f8319d6dad9
--- /dev/null
+++ b/runtime/src/iree/builtins/mips/my_matmul_kernel.h
@@ -0,0 +1,40 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// C ABI for the MIPS custom matmul kernel.
+//
+// The kernel computes:
+//   C[m, n] = sum_k A[m, k] * B[k, n]
+//
+// Each matrix is passed as a base pointer plus explicit strided-layout
+// parameters that match what MLIR's memref.extract_strided_metadata produces.
+// This matches the calling convention emitted by LowerMIPSToFuncCallPass.
+
+#ifndef IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_
+#define IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_
+
+// Restored: the header name was stripped in the archived patch; <stdint.h>
+// is required for the int64_t parameters declared below.
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// void my_matmul_kernel(
+//     float *A, int64_t A_offset, int64_t A_stride0, int64_t A_stride1,
+//     float *B, int64_t B_offset, int64_t B_stride0, int64_t B_stride1,
+//     float *C, int64_t C_offset, int64_t C_stride0, int64_t C_stride1,
+//     int64_t M, int64_t N, int64_t K);
+void my_matmul_kernel(float *A, int64_t A_offset, int64_t A_stride0,
+                      int64_t A_stride1, float *B, int64_t B_offset,
+                      int64_t B_stride0, int64_t B_stride1, float *C,
+                      int64_t C_offset, int64_t C_stride0, int64_t C_stride1,
+                      int64_t M, int64_t N, int64_t K);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // IREE_BUILTINS_MIPS_MY_MATMUL_KERNEL_H_
From a459d7854a5b1723b2ceedb7641aeb22a5282f17 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Tue, 3 Mar 2026 15:12:17 +0530
Subject: [PATCH 4/5] [MIPS] Add compile_mips.zsh and run_mips.zsh end-to-end
 test scripts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Provides two
convenience scripts for exercising the full
torch.aten.mm → mips.matmul → func.call → vmfb → kernel plugin pipeline.

compile_mips.zsh:
  Step 1 — verifies torch.aten.mm is converted to mips.matmul by the
  ConvertTorchToMIPSPass (iree-opt smoke check).
  Step 2 — runs the full torch-to-iree pipeline with use-mips-matmul=true,
  producing the IREE input IR (/tmp/mm_iree.mlir).
  Step 3 — compiles the IREE input IR to a vmfb (/tmp/mm_mips.vmfb) with
  --mlir-print-ir-after-all, writing a per-pass IR dump to
  /tmp/mm_mips_ir_dump.mlir for debugging.

run_mips.zsh:
  Runs iree-run-module with --executable_plugin pointing at the built
  libmy_matmul_kernel.dylib. Tests A * I = A (matrix multiplied by identity)
  and prints the result for visual verification.
---
 compile_mips.zsh | 62 ++++++++++++++++++++++++++++++++++++++++++++++++
 run_mips.zsh     | 42 ++++++++++++++++++++++++++++++++
 2 files changed, 104 insertions(+)
 create mode 100755 compile_mips.zsh
 create mode 100755 run_mips.zsh

diff --git a/compile_mips.zsh b/compile_mips.zsh
new file mode 100755
index 000000000000..104fc98332d5
--- /dev/null
+++ b/compile_mips.zsh
@@ -0,0 +1,62 @@
+#!/bin/zsh
+# compile_mips.zsh
+#
+# Compiles a simple torch.aten.mm through the MIPS custom matmul path:
+#   torch.aten.mm
+#     -> mips.matmul                 (ConvertTorchToMIPSPass)
+#     -> flow.dispatch(...)
(IREE dispatch formation)
+#     -> func.call @my_matmul_kernel (bufferize via BufferizableOpInterface)
+#     -> LLVM / vmfb                 (iree-compile LLVMCPU backend)
+#
+# Output:  /tmp/mm_mips.vmfb
+# IR dump: /tmp/mm_mips_ir_dump.mlir (--mlir-print-ir-after-all)
+
+set -e  # exit on first error
+
+BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build
+IREE_OPT=$BUILD/tools/iree-opt
+IREE_COMPILE=$BUILD/tools/iree-compile
+
+# ── Input: 4x4 f32 matrix multiply ──────────────────────────────────────────
+cat > /tmp/mm_torch.mlir << 'EOF'
+module {
+  func.func @mm(%A: !torch.vtensor<[4,4],f32>,
+                %B: !torch.vtensor<[4,4],f32>)
+      -> !torch.vtensor<[4,4],f32> {
+    %0 = torch.aten.mm %A, %B
+        : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32>
+        -> !torch.vtensor<[4,4],f32>
+    return %0 : !torch.vtensor<[4,4],f32>
+  }
+}
+EOF
+
+# ── Step 1: Verify torch.aten.mm → mips.matmul ──────────────────────────────
+# NOTE: run the check inside an `if` — a bare `cmd | grep -q && echo OK`
+# under `set -e` would abort the script silently when the pattern is absent,
+# with no diagnostic at all.
+echo "==> Step 1: verifying torch.aten.mm → mips.matmul"
+if $IREE_OPT \
+     --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \
+     /tmp/mm_torch.mlir \
+     | grep -q "mips.matmul"; then
+  echo "    [OK] mips.matmul found in IR"
+else
+  echo "    [FAIL] mips.matmul not found in IR" >&2
+  exit 1
+fi
+
+# ── Step 2: Full torch → IREE input IR (with MIPS path enabled) ─────────────
+echo "==> Step 2: torch → IREE input IR (use-mips-matmul=true)"
+$IREE_OPT \
+  --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \
+  /tmp/mm_torch.mlir \
+  -o /tmp/mm_iree.mlir
+
+# ── Step 3: IREE input IR → vmfb (dispatch + bufferize + LLVM) ──────────────
+IR_DUMP=/tmp/mm_mips_ir_dump.mlir
+echo "==> Step 3: IREE input IR → vmfb (IR dump → $IR_DUMP)"
+$IREE_COMPILE \
+  --iree-hal-target-backends=llvm-cpu \
+  --iree-llvmcpu-link-embedded=false \
+  --mlir-print-ir-after-all \
+  /tmp/mm_iree.mlir \
+  -o /tmp/mm_mips.vmfb \
+  2>"$IR_DUMP"
+
+echo ""
+echo "==> Compiled successfully: /tmp/mm_mips.vmfb"
+echo "    IR dump written to:    $IR_DUMP"
+echo "    Run with:              ./run_mips.zsh"
diff --git a/run_mips.zsh b/run_mips.zsh
new file mode 100755
index
000000000000..d215f1e8efdd
--- /dev/null
+++ b/run_mips.zsh
@@ -0,0 +1,42 @@
+#!/bin/zsh
+# run_mips.zsh
+#
+# Runs the vmfb produced by compile_mips.zsh.
+#
+# The vmfb's dispatch executable calls @my_matmul_kernel through the IREE
+# import mechanism (not direct dynamic linking). The kernel is provided by
+# libmy_matmul_kernel.dylib which implements the IREE executable plugin API
+# (exports iree_hal_executable_plugin_query).
+#
+# Test: A * I = A (multiply by 4x4 identity → expect the same matrix back)
+
+set -e  # abort on unexpected failures (the -f checks below still exit 1)
+
+BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build
+KERNEL_LIB=$BUILD/runtime/src/iree/builtins/mips/libmy_matmul_kernel.dylib
+IREE_RUN=$BUILD/tools/iree-run-module
+VMFB=/tmp/mm_mips.vmfb
+
+# Diagnostics go to stderr so stdout stays clean for the result tensor.
+if [[ ! -f $VMFB ]]; then
+  echo "ERROR: $VMFB not found. Run compile_mips.zsh first." >&2
+  exit 1
+fi
+
+if [[ ! -f $KERNEL_LIB ]]; then
+  echo "ERROR: $KERNEL_LIB not found. Build my_matmul_kernel target first." >&2
+  exit 1
+fi
+
+# A = 1..16 (row-major 4x4), B = 4x4 identity
+A="4x4xf32=1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16"
+B="4x4xf32=1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1"
+
+echo "==> Running mm(A, I) via MIPS kernel"
+echo "    Kernel plugin : $KERNEL_LIB"
+echo "    Expected      : A * I = A (rows: [1 2 3 4], [5 6 7 8], ...)"
+echo ""
+
+# Quote path expansions defensively; script exit status is iree-run-module's.
+$IREE_RUN \
+  --executable_plugin="$KERNEL_LIB" \
+  --module="$VMFB" \
+  --function=mm \
+  --input="$A" \
+  --input="$B"
From c14a0d20c529ae12de268e762fa4b37d727ec320 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Tue, 3 Mar 2026 16:19:24 +0530
Subject: [PATCH 5/5] [MIPS] Add iree-build.zsh build configuration script

---
 iree-build.zsh | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100755 iree-build.zsh

diff --git a/iree-build.zsh b/iree-build.zsh
new file mode 100755
index 000000000000..8738503ba098
--- /dev/null
+++ b/iree-build.zsh
@@ -0,0 +1,44 @@
+#!/usr/bin/env zsh
+
+# Exit if any command fails
+set -e
+
+# Paths (customize as needed)
+SRC_DIR="$HOME/MLIR_Work/mips"
+IREE_SRC_DIR="$SRC_DIR/iree"
# Path to your cloned iree
+BUILD_DIR="$SRC_DIR/iree-build"
+
+# Configuration
+CMAKE_GENERATOR="Ninja"      # Change to "Unix Makefiles" if preferred
+BUILD_TYPE="RelWithDebInfo"  # "Debug" or "RelWithDebInfo" are alternatives
+# NOTE(review): `sysctl -n hw.logicalcpu` is macOS-only; use `nproc` on Linux.
+NUM_JOBS=$(sysctl -n hw.logicalcpu)  # Uses all CPU cores
+
+# Prepare directories
+mkdir -p "${BUILD_DIR}"
+
+# CMake configuration
+# NOTE(review): IREE_ENABLE_LLD=ON with /usr/bin/clang assumes lld is
+# available on this macOS toolchain — confirm, or drop the flag.
+cmake -S "${IREE_SRC_DIR}" -B "${BUILD_DIR}" \
+  -G "${CMAKE_GENERATOR}" \
+  -DCMAKE_BUILD_TYPE="${BUILD_TYPE}" \
+  -DIREE_ENABLE_ASSERTIONS=ON \
+  -DCMAKE_C_COMPILER_LAUNCHER=ccache \
+  -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
+  -DCMAKE_C_COMPILER=/usr/bin/clang \
+  -DCMAKE_CXX_COMPILER=/usr/bin/clang++ \
+  -DIREE_ENABLE_SPLIT_DWARF=ON \
+  -DIREE_ENABLE_LLD=ON \
+  -DIREE_TARGET_BACKEND_DEFAULTS=OFF \
+  -DIREE_TARGET_BACKEND_LLVM_CPU=ON \
+  -DIREE_HAL_DRIVER_DEFAULTS=OFF \
+  -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \
+  -DIREE_HAL_DRIVER_LOCAL_TASK=ON \
+  -DIREE_BUILD_PYTHON_BINDINGS=ON \
+  -DPython3_EXECUTABLE="$(command -v python)"
+
+# Build
+ninja -C "${BUILD_DIR}" -j"${NUM_JOBS}"
+
+# Or combine all steps using a utility target
+#cmake --build ../iree-build --target iree-run-tests
+
+# The script only configures and builds — it does not run any tests, so do
+# not claim that it did.
+echo "✅ IREE build completed successfully!"