From 7b43173241c249e22553d4a00c0ed35be7dac2cb Mon Sep 17 00:00:00 2001
From: WHoutstanding
Date: Tue, 27 Jan 2026 19:48:23 +0800
Subject: [PATCH 1/3] Add FP32_ONLY_FUNCS set to fix dtype generalization pass

---
 .../dtype_generalization_pass.py | 45 ++++++++++++++++---
 1 file changed, 38 insertions(+), 7 deletions(-)

diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index 718b39197..fbf186557 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,6 +32,19 @@
     "bmm",
 }
 
+FP32_ONLY_FUNCS = {
+    torch.nn.functional.softmax,
+    torch.nn.functional.layer_norm,
+    torch.nn.functional.group_norm,
+    torch.nn.functional.batch_norm,
+    torch.nn.functional.embedding,
+    torch.exp,
+    torch.log,
+    torch.pow,
+    torch.sigmoid,
+    torch.tanh,
+    torch.conv_transpose2d,
+}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -107,7 +120,7 @@ def create_get_attr(node: fx.Node) -> fx.Node:
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
-        def create_new_args(node: fx.Node) -> list:
+        def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
             """new_args of node with dtype conversion if needed."""
             new_args = []
 
@@ -115,7 +128,10 @@ def create_new_args(node: fx.Node) -> list:
                 if isinstance(arg, fx.Node):
                     mapped = val_map[arg]
                     if self._is_float32_tensor(arg):
-                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                        if target_dtype == torch.float32:
+                            mapped = new_graph.call_method("to", (mapped, torch.float32))
+                        elif target_dtype == self.torch_dtype:
+                            mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
                     new_args.append(mapped)
                 else:
                     new_args.append(arg)
@@ -123,10 +139,13 @@
         def create_call_function(node: fx.Node) -> fx.Node:
             """Create a call_function node with dtype conversion if needed."""
-            if node.target not in AMP_CALL_FUNCTION:
-                return new_graph.node_copy(node, lambda x: val_map[x])
+            require_fp32 = is_fp32_node(node)
+            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
+
+            if node.target not in AMP_CALL_FUNCTION and not require_fp32:
+                return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node)
+            new_args = create_new_args(node, target_dtype)
 
             new_kwargs = {
                 k: val_map[v] if isinstance(v, fx.Node) else v
                 for k, v in node.kwargs.items()
@@ -140,10 +159,14 @@ def create_call_function(node: fx.Node) -> fx.Node:
             )
 
         def create_call_method(node: fx.Node) -> fx.Node:
-            if node.target not in AMP_CALL_METHOD:
+            """Create a call_method node with dtype conversion if needed."""
+            require_fp32 = is_fp32_node(node)
+            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
+
+            if node.target not in AMP_CALL_METHOD and not require_fp32:
                 return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node)
+            new_args = create_new_args(node, target_dtype)
 
             new_kwargs = {
                 k: (val_map[v] if isinstance(v, fx.Node) else v)
                 for k, v in node.kwargs.items()
@@ -156,6 +179,14 @@ def create_call_method(node: fx.Node) -> fx.Node:
                 new_kwargs,
             )
 
+        def is_fp32_node(node: fx.Node) -> bool:
+            """Check whether a node is a float32-only op."""
+            if node.op == "call_function":
+                return node.target in FP32_ONLY_FUNCS
+            elif node.op == "call_method":
+                return node.target in AMP_CALL_METHOD
+            return False
+
         for node in gm.graph.nodes:
             if node.op == "placeholder":
                 val_map[node] = create_placeholder(node)

From e269e2e4eb6108807f31209c43aa9ad67da09f42 Mon Sep 17 00:00:00 2001
From: WHoutstanding
Date: Wed, 28 Jan 2026 15:25:39 +0800
Subject: [PATCH 2/3] Fix bug of invalid dtype for bias

---
 .../dtype_generalization_pass.py | 73 +++++++------------
 1 file changed, 28 insertions(+), 45 deletions(-)

diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index fbf186557..91400d1e0 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,19 +32,6 @@
     "bmm",
 }
 
-FP32_ONLY_FUNCS = {
-    torch.nn.functional.softmax,
-    torch.nn.functional.layer_norm,
-    torch.nn.functional.group_norm,
-    torch.nn.functional.batch_norm,
-    torch.nn.functional.embedding,
-    torch.exp,
-    torch.log,
-    torch.pow,
-    torch.sigmoid,
-    torch.tanh,
-    torch.conv_transpose2d,
-}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -107,6 +94,10 @@
         def create_placeholder(node: fx.Node) -> fx.Node:
             """Create a placeholder node with dtype conversion if needed."""
             new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
             if self._is_float32_tensor(node):
+                attr_name = str(node.target)
+                if self.should_preserve_weight(attr_name):
+                    return new_node
+
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
@@ -120,7 +111,7 @@ def create_get_attr(node: fx.Node) -> fx.Node:
                 return new_graph.call_method("to", args=(new_node, self.torch_dtype))
             return new_node
 
-        def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
+        def create_new_args(node: fx.Node) -> list:
             """new_args of node with dtype conversion if needed."""
             new_args = []
 
@@ -128,29 +119,35 @@ def create_new_args(node: fx.Node, target_dtype: torch.dtype) -> list:
                 if isinstance(arg, fx.Node):
                     mapped = val_map[arg]
                     if self._is_float32_tensor(arg):
-                        if target_dtype == torch.float32:
-                            mapped = new_graph.call_method("to", (mapped, torch.float32))
-                        elif target_dtype == self.torch_dtype:
-                            mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
                     new_args.append(mapped)
                 else:
                     new_args.append(arg)
             return new_args
 
+        def create_new_kwargs(node: fx.Node) -> dict:
+            """new_kwargs of node with dtype conversion if needed."""
+            new_kwargs = {}
+
+            for k, v in node.kwargs.items():
+                if isinstance(v, fx.Node):
+                    mapped = val_map[v]
+                    if self._is_float32_tensor(v):
+                        mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
+                    else:
+                        new_kwargs[k] = mapped
+                else:
+                    new_kwargs[k] = v
+            return new_kwargs
+
         def create_call_function(node: fx.Node) -> fx.Node:
             """Create a call_function node with dtype conversion if needed."""
-            require_fp32 = is_fp32_node(node)
-            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
-
-            if node.target not in AMP_CALL_FUNCTION and not require_fp32:
-                return new_graph.node_copy(node, lambda x: val_map[x])
+            if node.target not in AMP_CALL_FUNCTION:
+                return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node, target_dtype)
+            new_args = create_new_args(node)
 
-            new_kwargs = {
-                k: val_map[v] if isinstance(v, fx.Node) else v
-                for k, v in node.kwargs.items()
-            }
+            new_kwargs = create_new_kwargs(node)
 
             return new_graph.call_function(
                 node.target,
@@ -160,18 +157,12 @@ def create_call_method(node: fx.Node) -> fx.Node:
             """Create a call_method node with dtype conversion if needed."""
-            require_fp32 = is_fp32_node(node)
-            target_dtype = torch.float32 if require_fp32 else self.torch_dtype
-
-            if node.target not in AMP_CALL_METHOD and not require_fp32:
+            if node.target not in AMP_CALL_METHOD:
                 return new_graph.node_copy(node, lambda x: val_map[x])
 
-            new_args = create_new_args(node, target_dtype)
+            new_args = create_new_args(node)
 
-            new_kwargs = {
-                k: (val_map[v] if isinstance(v, fx.Node) else v)
-                for k, v in node.kwargs.items()
-            }
+            new_kwargs = create_new_kwargs(node)
 
             return new_graph.call_method(
                 node.target,
                 new_kwargs,
             )
@@ -179,14 +170,6 @@ def create_call_method(node: fx.Node) -> fx.Node:
                 new_kwargs,
             )
 
-        def is_fp32_node(node: fx.Node) -> bool:
-            """Check whether a node is a float32-only op."""
-            if node.op == "call_function":
-                return node.target in FP32_ONLY_FUNCS
-            elif node.op == "call_method":
-                return node.target in AMP_CALL_METHOD
-            return False
-
         for node in gm.graph.nodes:
             if node.op == "placeholder":
                 val_map[node] = create_placeholder(node)

From 3e76ddecc75af7daf8f4db98c61c023d067ba742 Mon Sep 17 00:00:00 2001
From: WHoutstanding
Date: Thu, 29 Jan 2026 16:48:38 +0800
Subject: [PATCH 3/3] Fix FP32-only op handling with FP32_SENSITIVE_OPS

---
 .../dtype_generalization_pass.py | 78 +++++++++++++++----
 1 file changed, 61 insertions(+), 17 deletions(-)

diff --git a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
index 91400d1e0..46b77b5a4 100644
--- a/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
+++ b/graph_net/torch/dtype_gen_passes/dtype_generalization_pass.py
@@ -32,6 +32,12 @@
     "bmm",
 }
 
+FP32_SENSITIVE_OPS = {
+    torch.nn.functional.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
+    torch.nn.functional.group_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
+    torch.nn.functional.batch_norm: ({1, 2, 3, 4}, {"running_mean", "running_var", "weight", "bias"}),
+    torch.nn.functional.embedding: ({0}, {"weight"}),
+}
 
 class ConcretePass(DtypeGeneralizationPass):
     """
@@ -78,6 +84,27 @@ def _node_need_rewrite(self, node: fx.Node) -> bool:
 
         return False
 
+    def _analyze_preserved_nodes(self, graph: fx.Graph) -> set[fx.Node]:
+        """Pre-scan the graph: find all parameter nodes used by FP32-sensitive ops."""
+        preserved_nodes = set()
+
+        for node in graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target in FP32_SENSITIVE_OPS:
+                target_indices, target_kwargs = FP32_SENSITIVE_OPS[node.target]
+
+                for i, arg in enumerate(node.args):
+                    if i in target_indices and isinstance(arg, fx.Node):
+                        preserved_nodes.add(arg)
+
+                for k, v in node.kwargs.items():
+                    if k in target_kwargs and isinstance(v, fx.Node):
+                        preserved_nodes.add(v)
+
+        return preserved_nodes
+
     def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
         """
         Rewrite the graph to convert dtypes.
@@ -89,27 +116,41 @@ def rewrite(self, gm: fx.GraphModule) -> fx.GraphModule:
         """
         new_graph = fx.Graph()
         val_map = {}
-
+        preserved_nodes = self._analyze_preserved_nodes(gm.graph)
+
         def create_placeholder(node: fx.Node) -> fx.Node:
             """Create a placeholder node with dtype conversion if needed."""
             new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
-            if self._is_float32_tensor(node):
-                attr_name = str(node.target)
-                if self.should_preserve_weight(attr_name):
-                    return new_node
-
-                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
-            return new_node
+            # if self._is_float32_tensor(node):
+            #     attr_name = str(node.target)
+            #     if self.should_preserve_weight(attr_name):
+            #         return new_node
+            #     return new_graph.call_method("to", args=(new_node, self.torch_dtype))
+            # return new_node
+            if not self._is_float32_tensor(node):
+                return new_node
+
+            if node in preserved_nodes:
+                return new_node
+
+            return new_graph.call_method("to", args=(new_node, self.torch_dtype))
 
         def create_get_attr(node: fx.Node) -> fx.Node:
             """Create a get_attr node with dtype conversion if needed."""
             new_node = new_graph.node_copy(node, lambda x: val_map.get(x, x))
-            attr_name = str(node.target)
-            if self._is_float32_tensor(node) and not self.should_preserve_weight(
-                attr_name
-            ):
-                return new_graph.call_method("to", args=(new_node, self.torch_dtype))
-            return new_node
+            # attr_name = str(node.target)
+            # if self._is_float32_tensor(node) and not self.should_preserve_weight(
+            #     attr_name
+            # ):
+            #     return new_graph.call_method("to", args=(new_node, self.torch_dtype))
+            # return new_node
+            if not self._is_float32_tensor(node):
+                return new_node
+
+            if node in preserved_nodes:
+                return new_node
+
+            return new_graph.call_method("to", args=(new_node, self.torch_dtype))
 
         def create_new_args(node: fx.Node) -> list:
             """new_args of node with dtype conversion if needed."""
@@ -128,16 +169,16 @@
         def create_new_kwargs(node: fx.Node) -> dict:
             """new_kwargs of node with dtype conversion if needed."""
             new_kwargs = {}
-
+
             for k, v in node.kwargs.items():
                 if isinstance(v, fx.Node):
                     mapped = val_map[v]
                     if self._is_float32_tensor(v):
                         mapped = new_graph.call_method("to", (mapped, self.torch_dtype))
-                    else:
+                    else:
                         new_kwargs[k] = mapped
                 else:
-                    new_kwargs[k] = v
+                    new_kwargs[k] = v
             return new_kwargs
 
         def create_call_function(node: fx.Node) -> fx.Node:
@@ -186,6 +227,9 @@ def create_call_method(node: fx.Node) -> fx.Node:
         gm.graph = new_graph
         gm.recompile()
 
+        with open("output.txt", "w", encoding="utf-8") as f:
+            print(gm.graph, file=f)
+
         return gm
 
     def _is_float32_tensor(self, node: fx.Node) -> bool:
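
A minimal standalone sketch (not part of the patch series above) of how the FP32_SENSITIVE_OPS table in PATCH 3/3 is meant to be read: for torch.nn.functional.layer_norm, positional args 2-4 and the weight/bias/eps kwargs carry parameters that should stay in float32, so any fx.Node feeding those slots is collected into the preserved set and skipped by the .to(dtype) insertion. The model class and variable names below are invented for illustration; only public torch / torch.fx APIs are used, and the scan mirrors the logic of _analyze_preserved_nodes rather than calling the pass itself.

    import torch
    import torch.nn.functional as F
    from torch import fx

    # Same shape as the table added in PATCH 3/3: op -> (positional indices, kwarg names)
    FP32_SENSITIVE_OPS = {
        F.layer_norm: ({2, 3, 4}, {"weight", "bias", "eps"}),
    }

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(8))
            self.bias = torch.nn.Parameter(torch.zeros(8))

        def forward(self, x):
            # weight/bias feed an FP32-sensitive op, so the pass should keep them in fp32
            return F.layer_norm(x, (8,), self.weight, self.bias)

    gm = fx.symbolic_trace(TinyModel())

    preserved = set()
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in FP32_SENSITIVE_OPS:
            arg_indices, kwarg_names = FP32_SENSITIVE_OPS[node.target]
            # Positional parameters at the listed indices are preserved in float32
            for i, arg in enumerate(node.args):
                if i in arg_indices and isinstance(arg, fx.Node):
                    preserved.add(arg)
            # Keyword parameters with the listed names are preserved as well
            for k, v in node.kwargs.items():
                if k in kwarg_names and isinstance(v, fx.Node):
                    preserved.add(v)

    print([n.name for n in preserved])  # expected: the get_attr nodes for weight and bias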