From 6851d825cb7f33093a4880b706591d42e8354571 Mon Sep 17 00:00:00 2001 From: qti-ashimaj Date: Wed, 24 Dec 2025 11:08:59 +0530 Subject: [PATCH] add flag to apply DeduplicateHashedInitializersPass --- olive/passes/onnx/graph_surgeries.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 1eb6b799f..ae11b092e 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -2074,6 +2074,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon required=True, description="List of surgeries to apply, each with its type and parameters", ), + "remove_duplicate_initializers": PassConfigParam( + type_=bool, + default_value=True, + description=""" + Apply DeduplicateHashedInitializersPass after graph surgeries in case graph surgeries add duplicated initializers + """, + ), **get_external_data_config(), } @@ -2089,8 +2096,10 @@ def _run_for_config( surgeon_instance = self.init_surgeon_instance(surgery) onnx_model = surgeon_instance(onnx_model) - deduped_model = DeduplicateHashedInitializersPass()(ir.from_proto(onnx_model)).model - return model_proto_to_olive_model(ir.to_proto(deduped_model), output_model_path, config) + if config.remove_duplicate_initializers: + deduped_model = DeduplicateHashedInitializersPass()(ir.from_proto(onnx_model)).model + return model_proto_to_olive_model(ir.to_proto(deduped_model), output_model_path, config) + return model_proto_to_olive_model(onnx_model, output_model_path, config) def init_surgeon_instance(self, surgery): surgeon_name = surgery.get("surgeon").lower()