Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions compile_mips.zsh
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/bin/zsh
# compile_mips.zsh
#
# Compiles a simple torch.aten.mm through the MIPS custom matmul path:
#   torch.aten.mm
#     -> mips.matmul                 (ConvertTorchToMIPSPass)
#     -> flow.dispatch(...)          (IREE dispatch formation)
#     -> func.call @my_matmul_kernel (bufferize via BufferizableOpInterface)
#     -> LLVM / vmfb                 (iree-compile LLVMCPU backend)
#
# Output:  /tmp/mm_mips.vmfb
# IR dump: /tmp/mm_mips_ir_dump.mlir (--mlir-print-ir-after-all)

set -e          # exit on first error
set -u          # referencing an unset variable is an error (catches typos)
set -o pipefail # a failure anywhere in a pipeline fails the pipeline
                # (without this, only grep's status would be seen in Step 1)

BUILD=/Users/gauravshukla/MLIR_Work/mips/iree-build
IREE_OPT="$BUILD/tools/iree-opt"
IREE_COMPILE="$BUILD/tools/iree-compile"

# Fail early with a readable message if the build tree is missing/stale.
for tool in "$IREE_OPT" "$IREE_COMPILE"; do
  if [[ ! -x "$tool" ]]; then
    echo "error: '$tool' not found or not executable (check BUILD=$BUILD)" >&2
    exit 1
  fi
done

# ── Input: 4x4 f32 matrix multiply ────────────────────────────────────────────
cat > /tmp/mm_torch.mlir << 'EOF'
module {
  func.func @mm(%A: !torch.vtensor<[4,4],f32>,
                %B: !torch.vtensor<[4,4],f32>)
      -> !torch.vtensor<[4,4],f32> {
    %0 = torch.aten.mm %A, %B
        : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32>
        -> !torch.vtensor<[4,4],f32>
    return %0 : !torch.vtensor<[4,4],f32>
  }
}
EOF

# ── Step 1: Verify torch.aten.mm → mips.matmul ────────────────────────────────
echo "==> Step 1: verifying torch.aten.mm → mips.matmul"
if "$IREE_OPT" \
    --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \
    /tmp/mm_torch.mlir \
    | grep -q "mips.matmul"; then
  echo " [OK] mips.matmul found in IR"
else
  # Emit an explicit diagnostic instead of exiting silently via set -e.
  echo " [FAIL] mips.matmul NOT found after torch-iree-to-mips-matmul" >&2
  exit 1
fi

# ── Step 2: Full torch → IREE input IR (with MIPS path enabled) ───────────────
echo "==> Step 2: torch → IREE input IR (use-mips-matmul=true)"
"$IREE_OPT" \
  --pass-pipeline="builtin.module(torch-to-iree{use-mips-matmul=true})" \
  /tmp/mm_torch.mlir \
  -o /tmp/mm_iree.mlir

# ── Step 3: IREE input IR → vmfb (dispatch + bufferize + LLVM) ───────────────
IR_DUMP=/tmp/mm_mips_ir_dump.mlir
echo "==> Step 3: IREE input IR → vmfb (IR dump → $IR_DUMP)"
# --mlir-print-ir-after-all writes to stderr, so capture stderr as the dump.
"$IREE_COMPILE" \
  --iree-hal-target-backends=llvm-cpu \
  --iree-llvmcpu-link-embedded=false \
  --mlir-print-ir-after-all \
  /tmp/mm_iree.mlir \
  -o /tmp/mm_mips.vmfb \
  2>"$IR_DUMP"

echo ""
echo "==> Compiled successfully: /tmp/mm_mips.vmfb"
echo "    IR dump written to: $IR_DUMP"
echo "    Run with: ./run_mips.zsh"
2 changes: 2 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ iree_cc_library(
"BindSymbolicShapes.cpp"
"BitCastTensor.cpp"
"ConvertTMTensorToLinalgExt.cpp"
"ConvertTorchToMIPS.cpp"
"ConvertTorchUnstructuredToLinalgExt.cpp"
"FuncConversion.cpp"
"SetStrictSymbolicShapes.cpp"
Expand All @@ -57,6 +58,7 @@ iree_cc_library(
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::MIPS::IR
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::TensorExt::IR
PUBLIC
Expand Down
155 changes: 155 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/ConvertTorchToMIPS.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Converts torch.aten.mm → mips.matmul.
//
// The pattern runs inside the Torch input-conversion pipeline, BEFORE
// createConvertTorchToLinalgPass(), so it intercepts aten.mm first.
//
// Since torch ops carry ValueTensorType (torch's tensor type), the pattern:
// 1. Casts operands to builtin RankedTensorType via ToBuiltinTensorOp.
// 2. Creates a zero-initialised init tensor (Destination Passing Style).
// 3. Emits mips.matmul on builtin tensors.
// 4. Casts the result back to ValueTensorType via FromBuiltinTensorOp.
//
// This mirrors the approach in ConvertTorchUnstructuredToLinalgExt.cpp.

#include "compiler/plugins/input/Torch/InputConversion/Passes.h"
#include "iree/compiler/Dialect/MIPS/IR/MIPSOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

namespace mlir::iree_compiler::TorchInput {

#define GEN_PASS_DEF_CONVERTTORCHTOMIPSPASS
#include "compiler/plugins/input/Torch/InputConversion/Passes.h.inc"

namespace {

//===----------------------------------------------------------------------===//
// Helper: build an all-zero tensor of type `ty`.
// `dynSizes` supplies one SSA Value per dynamic dimension of `ty`, in
// dimension order (here: the dynamic M and/or N of the matmul result).
//===----------------------------------------------------------------------===//

static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                              RankedTensorType ty, ValueRange dynSizes) {
  // Allocate an uninitialized destination, materialize a scalar zero of the
  // element type, and splat it in with linalg.fill (DPS-friendly form).
  Value dest = tensor::EmptyOp::create(rewriter, loc, ty, dynSizes);
  auto zeroAttr = cast<TypedAttr>(rewriter.getZeroAttr(ty.getElementType()));
  Value zeroScalar = arith::ConstantOp::create(rewriter, loc, zeroAttr);
  return linalg::FillOp::create(rewriter, loc, zeroScalar, dest).result();
}

//===----------------------------------------------------------------------===//
// Pattern: torch.aten.mm → mips.matmul
//===----------------------------------------------------------------------===//

struct ConvertAtenMmToMIPSMatmul
    : public OpRewritePattern<torch::Torch::AtenMmOp> {
  using OpRewritePattern::OpRewritePattern;

  // Rewrites `torch.aten.mm %lhs, %rhs` into:
  //   %l    = torch_c.to_builtin_tensor %lhs
  //   %r    = torch_c.to_builtin_tensor %rhs
  //   %init = linalg.fill 0.0, tensor.empty(...)
  //   %m    = mips.matmul %l, %r, %init
  //   torch_c.from_builtin_tensor %m
  LogicalResult matchAndRewrite(torch::Torch::AtenMmOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // ----------------------------------------------------------------
    // 1. Verify that we have supported tensor types.
    // ----------------------------------------------------------------
    auto lhsTorchTy =
        dyn_cast<torch::Torch::ValueTensorType>(op.getSelf().getType());
    auto rhsTorchTy =
        dyn_cast<torch::Torch::ValueTensorType>(op.getMat2().getType());
    auto resultTorchTy =
        dyn_cast<torch::Torch::ValueTensorType>(op.getType());

    if (!lhsTorchTy || !rhsTorchTy || !resultTorchTy)
      return rewriter.notifyMatchFailure(op, "expected ValueTensorType");

    // getDtype() on a ValueTensorType with no dtype attached is invalid
    // (yields a null Type), so guard with hasDtype() rather than crashing on
    // partially-typed IR.
    if (!lhsTorchTy.hasDtype() || !rhsTorchTy.hasDtype() ||
        !resultTorchTy.hasDtype())
      return rewriter.notifyMatchFailure(op, "expected tensors with dtype");

    // Only handle f32 for now (extensible). Check all three types — not just
    // the LHS — so a mixed-dtype mm is rejected instead of mis-lowered.
    if (!lhsTorchTy.getDtype().isF32() || !rhsTorchTy.getDtype().isF32() ||
        !resultTorchTy.getDtype().isF32())
      return rewriter.notifyMatchFailure(op, "only f32 supported");

    // ----------------------------------------------------------------
    // 2. Cast operands from torch ValueTensorType → builtin RankedTensorType.
    // ----------------------------------------------------------------
    auto lhsBuiltinTy =
        dyn_cast_or_null<RankedTensorType>(lhsTorchTy.toBuiltinTensor());
    auto rhsBuiltinTy =
        dyn_cast_or_null<RankedTensorType>(rhsTorchTy.toBuiltinTensor());
    auto resultBuiltinTy =
        dyn_cast_or_null<RankedTensorType>(resultTorchTy.toBuiltinTensor());

    if (!lhsBuiltinTy || !rhsBuiltinTy || !resultBuiltinTy ||
        lhsBuiltinTy.getRank() != 2 || rhsBuiltinTy.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "expected 2-D ranked tensors");

    Value lhs = torch::TorchConversion::ToBuiltinTensorOp::create(
        rewriter, loc, lhsBuiltinTy, op.getSelf());
    Value rhs = torch::TorchConversion::ToBuiltinTensorOp::create(
        rewriter, loc, rhsBuiltinTy, op.getMat2());

    // ----------------------------------------------------------------
    // 3. Collect dynamic dimension values for the result tensor.
    //    M comes from lhs dim 0, N from rhs dim 1.
    // ----------------------------------------------------------------
    SmallVector<Value> dynSizes;
    if (resultBuiltinTy.isDynamicDim(0))
      dynSizes.push_back(tensor::DimOp::create(rewriter, loc, lhs, 0));
    if (resultBuiltinTy.isDynamicDim(1))
      dynSizes.push_back(tensor::DimOp::create(rewriter, loc, rhs, 1));

    // ----------------------------------------------------------------
    // 4. Create a zero-initialised init tensor for DPS output.
    // ----------------------------------------------------------------
    Value init = createZeroTensor(rewriter, loc, resultBuiltinTy, dynSizes);

    // ----------------------------------------------------------------
    // 5. Emit mips.matmul on builtin tensors.
    // ----------------------------------------------------------------
    Value result =
        IREE::MIPS::MatmulOp::create(rewriter, loc, TypeRange{resultBuiltinTy},
                                     lhs, rhs, init)
            .getResult();

    // ----------------------------------------------------------------
    // 6. Cast result back to ValueTensorType so downstream torch passes can
    //    still operate on it until the type finalisation pass runs.
    // ----------------------------------------------------------------
    Value torchResult = torch::TorchConversion::FromBuiltinTensorOp::create(
        rewriter, loc, resultTorchTy, result);

    rewriter.replaceOp(op, torchResult);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

// Pass driver: greedily applies the aten.mm → mips.matmul pattern over a
// function body.
struct ConvertTorchToMIPSPass
    : impl::ConvertTorchToMIPSPassBase<ConvertTorchToMIPSPass> {
  // Declare every dialect whose ops the rewrite may introduce, so they are
  // loaded before pattern application.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<IREE::MIPS::MIPSDialect>();
    registry.insert<torch::TorchConversion::TorchConversionDialect>();
    registry.insert<arith::ArithDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<linalg::LinalgDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patternSet(ctx);
    patternSet.add<ConvertAtenMmToMIPSMatmul>(ctx);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patternSet)))) {
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler::TorchInput
5 changes: 5 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ void createTorchToIREEPipeline(
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToTensorPass());
pm.addNestedPass<func::FuncOp>(
TorchInput::createConvertTorchUnstructuredToLinalgExtPass());
// MIPS: When enabled, intercept aten.mm before the standard torch->linalg
// pass and route it through mips.matmul -> func.call @my_matmul_kernel.
if (options.useMIPSMatmul) {
pm.addNestedPass<func::FuncOp>(TorchInput::createConvertTorchToMIPSPass());
}
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToLinalgPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(torch::createConvertTorchToSCFPass());
Expand Down
6 changes: 6 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ struct TorchToIREELoweringPipelineOptions
"program inputs. This buffer will be used for storing transient "
"memory and must be provided by the user."),
llvm::cl::init(false)};
Option<bool> useMIPSMatmul{
*this, "use-mips-matmul",
llvm::cl::desc("If enabled, lowers torch.aten.mm through the MIPS "
"custom dialect (mips.matmul) instead of the standard "
"torch->linalg path."),
llvm::cl::init(false)};
};

// Creates a pipeline that lowers from the torch backend contract to IREE.
Expand Down
10 changes: 10 additions & 0 deletions compiler/plugins/input/Torch/InputConversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def ConvertTorchUnstructuredToLinalgExtPass :
let summary = "Convert unstructured Torch ops to LinalgExt ops";
}

// Pipeline position: added (behind the `use-mips-matmul` pipeline option)
// just before the torch->linalg conversion in createTorchToIREEPipeline, so
// aten.mm is claimed by the MIPS path first.
def ConvertTorchToMIPSPass :
InterfacePass<"torch-iree-to-mips-matmul", "mlir::FunctionOpInterface"> {
let summary = "Convert torch.aten.mm to mips.matmul";
let description = [{
Intercepts torch.aten.mm before the standard torch->linalg conversion and
replaces it with mips.matmul. The mips.matmul op is later lowered to a
func.call @my_matmul_kernel by LowerMIPSToFuncCallPass after bufferization.
}];
}

def SetStrictSymbolicShapesPass :
InterfacePass<"torch-iree-set-strict-symbolic-shapes", "mlir::FunctionOpInterface"> {
let summary = "Adds the attribute indicating strict symbolic shapes in Torch IR";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ iree_lit_test_suite(
"attention.mlir"
"bind_symbolic_shapes.mlir"
"bitcast_tensor.mlir"
"convert_torch_to_mips.mlir"
"func_conversion.mlir"
"func_conversion_invalid.mlir"
"func_conversion_transients.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: iree-opt --split-input-file \
// RUN: --pass-pipeline="builtin.module(func.func(torch-iree-to-mips-matmul))" \
// RUN: %s | FileCheck %s

// Exercises the torch-iree-to-mips-matmul pattern in isolation: f32 aten.mm
// is rewritten to mips.matmul on builtin tensors (bridged by torch_c casts);
// unsupported dtypes must be left untouched.

// ─────────────────────────────────────────────────────────────────────────────
// Static-shape: torch.aten.mm on f32 tensors → mips.matmul
// ─────────────────────────────────────────────────────────────────────────────

// CHECK-LABEL: func.func @mm_static
// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<4x8xf32>
// CHECK: torch_c.to_builtin_tensor {{.*}} -> tensor<8x4xf32>
// CHECK: mips.matmul {{.*}} : tensor<4x8xf32>, tensor<8x4xf32>, tensor<4x4xf32> -> tensor<4x4xf32>
// CHECK-NOT: torch.aten.mm
func.func @mm_static(%A: !torch.vtensor<[4,8],f32>,
%B: !torch.vtensor<[8,4],f32>)
-> !torch.vtensor<[4,4],f32> {
%0 = torch.aten.mm %A, %B
: !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,4],f32>
-> !torch.vtensor<[4,4],f32>
return %0 : !torch.vtensor<[4,4],f32>
}

// ─────────────────────────────────────────────────────────────────────────────
// Non-f32 (i32) should be left untouched (pattern rejects non-f32 dtypes).
// ─────────────────────────────────────────────────────────────────────────────

// CHECK-LABEL: func.func @mm_i32_unchanged
// CHECK-NOT: mips.matmul
// CHECK: torch.aten.mm
func.func @mm_i32_unchanged(%A: !torch.vtensor<[4,8],si32>,
%B: !torch.vtensor<[8,4],si32>)
-> !torch.vtensor<[4,4],si32> {
%0 = torch.aten.mm %A, %B
: !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,4],si32>
-> !torch.vtensor<[4,4],si32>
return %0 : !torch.vtensor<[4,4],si32>
}
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ iree_cc_library(
iree::compiler::Dialect::LinalgExt::IR
iree::compiler::Dialect::LinalgExt::Transforms
iree::compiler::Dialect::LinalgExt::Utils
iree::compiler::Dialect::MIPS::Transforms
iree::compiler::Dialect::TensorExt::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
#include "iree/compiler/Codegen/Utils/CodegenOptions.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "iree/compiler/Dialect/MIPS/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -518,6 +519,10 @@ static void addLowerToLLVMPasses(OpPassManager &modulePassManager,
FunctionLikeNest(modulePassManager)
.addPass(createEraseHALDescriptorTypeFromMemRefPass);

// mips.matmul is eliminated during One-Shot Bufferize (func.call emitted
// directly by MIPSBufferizableOpInterface). This pass is now a no-op.
modulePassManager.addPass(IREE::MIPS::createLowerMIPSToFuncCallPass());

// Lower `ukernel.*` ops to function calls
modulePassManager.addPass(createLowerUKernelOpsToCallsPass());

Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Dialect/MIPS/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Recursively picks up IR/ and Transforms/ subdirectories.
# No targets are defined at this level; all MIPS libraries live in the
# subdirectories' own CMakeLists.
iree_add_all_subdirs()
Loading