Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ To speed up the build process, you can set up [ccache](https://ccache.dev/downlo
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
```

To enable parallelization with OpenMP runtime, add the following flag to the command above:

```shell
-DLLVM_ENABLE_RUNTIMES=openmp
```

Run the following to ensure the MPACT compiler builds and runs correctly.

```shell
Expand Down
39 changes: 33 additions & 6 deletions python/mpact/mpactbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,12 @@ def invoke(*args):
"func.func(refback-munge-memref-copy)",
"func.func(convert-linalg-to-loops)",
"func.func(lower-affine)",
"convert-scf-to-openmp{{{omp_options}}}",
"convert-scf-to-cf",
"func.func(refback-expand-ops-for-llvm)",
"func.func(arith-expand)",
"func.func(convert-math-to-llvm)",
"convert-openmp-to-llvm",
"convert-math-to-libm",
"expand-strided-metadata",
"finalize-memref-to-llvm",
Expand All @@ -276,17 +278,27 @@ def invoke(*args):
class MpactBackendCompiler:
"""Main entry-point for the MPACT backend compiler."""

def __init__(self, opt_level, use_sp_it):
def __init__(self, opt_level, use_sp_it, parallel, enable_ir_printing, num_threads):
self.opt_level = opt_level
self.use_sp_it = use_sp_it
self.parallel = parallel
self.enable_ir_printing = enable_ir_printing
self.num_threads = num_threads

def compile(self, imported_module: Module) -> MpactCompiledArtifact:
sp_options = (
"sparse-emit-strategy=sparse-iterator"
if self.use_sp_it
else "vl=16 enable-simd-index32"
)
LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(sp_options=sp_options)
omp_options = f"num-threads={self.num_threads}"
# TODO: enable the parallelization strategy
# once MLIR bump is completed.
# if self.parallel:
# sp_options += f" parallelization-strategy={self.parallel}"
LOWERING_PIPELINE = LOWERING_PIPELINE_TEMPLATE.format(
sp_options=sp_options, omp_options=omp_options
)
"""Compiles an imported module, with a flat list of functions.
The module is expected to be in linalg-on-tensors + scalar code form.

Expand All @@ -299,7 +311,7 @@ def compile(self, imported_module: Module) -> MpactCompiledArtifact:
imported_module,
LOWERING_PIPELINE,
"Lowering Linalg-on-Tensors IR to LLVM with MpactBackendCompiler",
enable_ir_printing=False,
enable_ir_printing=self.enable_ir_printing,
)
return imported_module

Expand Down Expand Up @@ -461,7 +473,16 @@ def export_and_import(f, *args, **kwargs):
return fx_importer.module


def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
def mpact_jit_compile(
f,
*args,
opt_level=2,
use_sp_it=False,
parallel="none",
enable_ir_printing=False,
num_threads=1,
**kwargs,
):
"""This method compiles the given callable using the MPACT backend."""
# Import module and lower into Linalg IR.
module = export_and_import(f, *args, **kwargs)
Expand All @@ -473,10 +494,16 @@ def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"torch-backend-to-linalg-on-tensors-backend-pipeline)"
),
"Lowering TorchFX IR -> Linalg IR",
enable_ir_printing=False,
enable_ir_printing=enable_ir_printing,
)
# Compile with MPACT backend compiler.
backend = MpactBackendCompiler(opt_level=opt_level, use_sp_it=use_sp_it)
backend = MpactBackendCompiler(
opt_level=opt_level,
use_sp_it=use_sp_it,
parallel=parallel,
enable_ir_printing=enable_ir_printing,
num_threads=num_threads,
)
compiled = backend.compile(module)
invoker = backend.load(compiled)
return invoker, f
Expand Down
5 changes: 5 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ add_lit_testsuite(check-mpact "Running the MPACT regression tests"
)
set_target_properties(check-mpact PROPERTIES FOLDER "Tests")

# TODO: find omp library.
# find_package(OpenMP REQUIRED)
# add_compile_options(${OpenMP_CXX_FLAGS})
# target_link_libraries(check-mpact OpenMP::OpenMP_CXX)

add_lit_testsuites(MPACT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
47 changes: 47 additions & 0 deletions test/python/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# RUN: %PYTHON -s %s 2>&1 | FileCheck %s

import gc
import sys
import torch
import numpy as np

from mpact.mpactbackend import mpact_jit

from mpact.models.kernels import MMNet


def run_test(f, *args, **kwargs):
print("TEST:", f.__name__, file=sys.stderr)
f(*args, **kwargs)
gc.collect()


net = MMNet()

# Construct dense and sparse matrices.
X = torch.arange(0, 16, dtype=torch.float32).view(4, 4)
Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
A = torch.tensor(
[
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 2.0],
[0.0, 0.0, 0.0, 0.0],
[3.0, 0.0, 0.0, 0.0],
],
dtype=torch.float32,
)
S = A.to_sparse_csr()

# Run it with MPACT.
# TODO: enable the check test.
# C-HECK: omp.parallel
# CHECK: openmp
run_test(
mpact_jit,
net,
X,
Y,
parallel="any-storage-any-loop",
enable_ir_printing=True,
num_threads=10,
)