From fa8b79e8b5df0734f250e6023f3a0b98c4af5bd5 Mon Sep 17 00:00:00 2001 From: Jianyi Cheng Date: Fri, 17 May 2024 16:51:12 +0000 Subject: [PATCH] Added insert fork pass --- .../graph/transforms/utils/insert_fork.py | 30 +++++++++++++++++++ .../verilog/test_emit_verilog_bert.py | 2 ++ 2 files changed, 32 insertions(+) create mode 100644 src/chop/passes/graph/transforms/utils/insert_fork.py diff --git a/src/chop/passes/graph/transforms/utils/insert_fork.py b/src/chop/passes/graph/transforms/utils/insert_fork.py new file mode 100644 index 000000000..1de891180 --- /dev/null +++ b/src/chop/passes/graph/transforms/utils/insert_fork.py @@ -0,0 +1,30 @@ +import torch + + +def insert_fork_transform_pass(graph, pass_args={}): + """Insert hardware-explicit forks into the mase graph + + :param graph: a MaseGraph + :type graph: MaseGraph + :param pass_args: this pass requires additional arguments which is explained below, defaults to {} + :type pass_args: _type_, optional + :return: return a tuple of a MaseGraph and an empty dict (no additional info to return) + :rtype: tuple(MaseGraph, Dict) + """ + + logger.info("Inserting forks...") + + nodes_to_fork = [] + for node in graph.fx_graph.nodes: + user_count = 0 + for u in node.users.keys(): + if u.meta["mase"].parameters["hardware"]["is_implicit"]: + user_count += 1 + if user_count > 1: + nodes_to_fork.append(node) + + for node in nodes_to_fork: + graph.fx_graph.inserting_after(node) + graph.fx_graph.create_node("call_module", torch.nn.Identity) + + return graph, _ diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py index 0f6f8aa50..91ad2a4ac 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_bert.py @@ -148,6 +148,8 @@ def test_emit_verilog_bert(): # i += 1 + mg, _ = passes.insert_fork_transform_pass(mg) + mg, _ = passes.emit_verilog_top_transform_pass(mg) # mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg)