diff --git a/graph_net_bench/torch/backend/flagtree_backend.py b/graph_net_bench/torch/backend/flagtree_backend.py
new file mode 100644
index 000000000..c63a1b35c
--- /dev/null
+++ b/graph_net_bench/torch/backend/flagtree_backend.py
@@ -0,0 +1,31 @@
+import torch
+from .graph_compiler_backend import GraphCompilerBackend
+
+try:
+    import flagtree
+except ImportError:
+    flagtree = None
+
+
+class FlagtreeBackend(GraphCompilerBackend):
+    def __init__(self, config):
+        super().__init__(config)
+        self.flagtree_backend = None
+
+    def __call__(self, model):
+        if flagtree is None:
+            raise ImportError("flagtree not installed")
+
+        if self.flagtree_backend is None:
+            self.flagtree_backend = flagtree.create_backend()
+
+        return self.flagtree_backend.compile(model)
+
+    def synchronize(self):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+    def version(self):
+        if flagtree is None:
+            return "not installed"
+        return flagtree.__version__
diff --git a/graph_net_bench/torch/eval_backend_perf.py b/graph_net_bench/torch/eval_backend_perf.py
index 3fd6db3ff..ecd37c65f 100644
--- a/graph_net_bench/torch/eval_backend_perf.py
+++ b/graph_net_bench/torch/eval_backend_perf.py
@@ -47,7 +47,7 @@ def get_hardward_name(device):
 def get_compiler_version(compiler):
     if compiler in ["inductor", "nope", "unstable_to_stable"]:
         return torch.__version__
-    elif compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
+    elif compiler in ["tvm", "flagtree", "xla", "tensorrt", "bladedisc"]:
         # Assuming compiler object has a version attribute
         return f"{compiler.capitalize()} {compiler.version}"
     return "unknown"
diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py
index 0923e19d6..e41d9aa42 100755
--- a/graph_net_bench/torch/test_compiler.py
+++ b/graph_net_bench/torch/test_compiler.py
@@ -16,6 +16,7 @@ import base64
 from graph_net_bench.torch.backend.graph_compiler_backend import GraphCompilerBackend
 from graph_net_bench.torch.backend.tvm_backend import TvmBackend
+from graph_net_bench.torch.backend.flagtree_backend import FlagtreeBackend
 from graph_net_bench.torch.backend.xla_backend import XlaBackend
 from graph_net_bench.torch.backend.inductor_backend import InductorBackend
 from graph_net_bench.torch.backend.tensorrt_backend import TensorRTBackend
@@ -37,6 +38,7 @@ compiler_backend_name2class = {
     "tvm": TvmBackend,
+    "flagtree": FlagtreeBackend,
     "xla": XlaBackend,
     "inductor": InductorBackend,
     "tensorrt": TensorRTBackend,
@@ -70,7 +72,7 @@ def get_hardward_name(args):
 def get_compile_framework_version(args):
     if args.compiler in ["inductor", "nope", "unstable_to_stable"]:
         return torch.__version__
-    elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
+    elif args.compiler in ["tvm", "flagtree", "xla", "tensorrt", "bladedisc"]:
         # Assuming compiler object has a version attribute
         return f"{args.compiler.capitalize()} {args.compiler.version}"
     return "unknown"
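
For context, the sketch below shows how the new backend would be exercised once registered. It is a minimal sketch, not part of the patch: the empty config dict, the placeholder Linear model, and the assumption that the object returned by flagtree's compile step is callable are all illustrative; flagtree.create_backend()/.compile() are used only as the new backend file itself assumes they exist.

    # Minimal usage sketch (assumes flagtree is importable and exposes
    # create_backend()/compile() as flagtree_backend.py expects).
    import torch
    from graph_net_bench.torch.backend.flagtree_backend import FlagtreeBackend

    backend = FlagtreeBackend(config={})   # hypothetical empty config
    print(backend.version())               # "not installed" when flagtree is absent

    model = torch.nn.Linear(8, 8)          # placeholder model
    compiled = backend(model)              # raises ImportError without flagtree
    out = compiled(torch.randn(4, 8))      # assumes the compiled artifact is callable
    backend.synchronize()                  # no-op on machines without CUDA

In test_compiler.py the same object is instead constructed through the compiler_backend_name2class lookup when the run is invoked with --compiler flagtree, so the snippet above only mirrors that path for manual experimentation.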