From 8deac8d9ebb0da28c3b9851a3ad2269c7e771e8e Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 8 Jan 2026 15:52:20 -0800 Subject: [PATCH 01/21] enable dyn shapes for pointwise and reduce fusions --- src/fuse_pointwise.cpp | 16 +++++++++------ src/fuse_reduce.cpp | 8 ++++++-- src/include/migraphx/op/pointwise.hpp | 7 +++++-- src/include/migraphx/shape.hpp | 2 ++ src/shape.cpp | 13 ++++++++++++ test/fuse_pointwise.cpp | 29 +++++++++++++++++++++++++++ test/fuse_reduce.cpp | 25 +++++++++++++++++++++++ test/include/reduce.hpp | 13 +++++++++--- 8 files changed, 100 insertions(+), 13 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 42c95514555..722597a63b2 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins) if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name())) return get_scalar(ins->inputs().front()); const auto& s = ins->get_shape(); - if(s.elements() != 1 and not(s.scalar())) + if(s.dynamic() or (s.elements() != 1 and not(s.scalar()))) return {}; if(not ins->can_eval()) return {}; @@ -340,16 +340,20 @@ struct pointwise_reshape : rewrite_reshapes_base static std::string name() { return "pointwise"; } }; -struct pointwise_broadcast_pointwise +struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes { auto matcher() const { + auto pointwise = match::name("pointwise")(match::used_once()).bind("x"); auto broadcast_pointwise = - match::name("multibroadcast")( - match::used_once(), - match::args(match::name("pointwise")(match::used_once()).bind("x"))) + match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise)); + auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(1)(pointwise)) + .bind("broadcast"); + return 
match::name("pointwise")(match::any_of[match::inputs()]( + match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..f1c7329e402 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -63,15 +63,19 @@ struct fused_reduce if(not sm->bypass()) MIGRAPHX_THROW("fused_reduce: bypass flag is not set"); auto names = sm->get_parameter_names(); - check_shapes{inputs, *this}.has(names.size()).same_ndims(); + check_shapes{inputs, *this, true}.has(names.size()).same_ndims(); std::sort(names.begin(), names.end()); auto shapes = sm->get_parameter_shapes(); // Check dimension matches for each input if(not equal(names, inputs, [&](const auto& name, const auto& input) { - return shapes.at(name).lens() == input.lens(); + auto s = shapes.at(name); + return shape::same_lens(input, s); })) MIGRAPHX_THROW("Input dimension does not match the submodule."); + if(sm->get_output_shapes().front().dynamic()) + return sm->get_output_shapes().front(); + return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), find_permutation(inputs)); diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 51ca75ee92b..03f16ca7d01 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -61,7 +61,10 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? 
std::vector{} : inputs.front().lens(); const auto rank = inputs.front().ndim(); const bool has_broadcasts = @@ -69,7 +72,7 @@ struct pointwise auto result = pm->compute_shapes( (rank > 1 and has_broadcasts) ? remove_broadcasts(inputs) : inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 22f6d1663d6..d8c9c5ffed2 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -359,6 +359,8 @@ struct MIGRAPHX_EXPORT shape MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const shape& x); + static bool same_lens(const shape& x, const shape& y); + template struct as { diff --git a/src/shape.cpp b/src/shape.cpp index 678bf5b53f0..c7a96ed9efd 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -828,6 +828,19 @@ std::ostream& operator<<(std::ostream& os, const shape& x) return os; } +bool shape::same_lens(const shape& x, const shape& y) +{ + if(x.dynamic() and y.dynamic()) + { + return x.dyn_dims() == y.dyn_dims(); + } + else if(x.dynamic() or y.dynamic()) + { + MIGRAPHX_THROW("SHAPE: same_lens() called on mixed dynamic and static shapes"); + } + return x.lens() == y.lens(); +} + shape::type_t shape::parse_type(const std::string& s) { static const std::unordered_map m = { diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 44e08e63c65..62751b1d831 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -67,6 +67,35 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto 
y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = mm->add_instruction(migraphx::make_op("add"), pass, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto z = mm->add_parameter("z", s); + auto add1 = add_pointwise(p2, "main:pointwise0", {x, y}, single_pointwise("add")); + auto pass = mm->add_instruction(pass_op{}, add1); + auto add2 = add_pointwise(p2, "main:pointwise1", {pass, z}, single_pointwise("add")); + mm->add_return({add2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(double_add) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..4f07a8be63a 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -63,6 +63,31 @@ TEST_CASE(single) EXPECT(p1 == p2); } +TEST_CASE(single_dyn) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x); + auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y); + mm->add_return({rsum1, rsum2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum")); + auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum")); + mm->add_return({rsum1, rsum2}); + } + EXPECT(p1 == p2); +} + TEST_CASE(pointwise_reduce) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; diff --git 
a/test/include/reduce.hpp b/test/include/reduce.hpp index 81d87566893..59d62b289fe 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -61,9 +61,16 @@ migraphx::module_ref add_reduce_module(migraphx::program& p, rm->set_bypass(); std::vector params; std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - return rm->add_parameter( - "x" + std::to_string(params.size()), - migraphx::shape{input->get_shape().type(), input->get_shape().lens()}); + migraphx::shape s; + if(input->get_shape().dynamic()) + { + s = input->get_shape(); + } + else + { + s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()}; + } + return rm->add_parameter("x" + std::to_string(params.size()), s); }); auto r = f(rm, params, axes); auto_add_return(rm, r); From d83f2a028fe26840897e32a00c9d9c26e92e7f22 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 09:10:20 -0800 Subject: [PATCH 02/21] format --- src/fuse_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index f1c7329e402..b8d8dc9833f 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -75,7 +75,7 @@ struct fused_reduce if(sm->get_output_shapes().front().dynamic()) return sm->get_output_shapes().front(); - + return shape::from_permutation(sm->get_output_shapes().front().type(), sm->get_output_shapes().front().lens(), find_permutation(inputs)); From 370e386c5b265f0a5f31b1723db796dd9137528e Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 9 Jan 2026 09:14:34 -0800 Subject: [PATCH 03/21] License update --- src/fuse_pointwise.cpp | 2 +- src/fuse_reduce.cpp | 2 +- src/include/migraphx/op/pointwise.hpp | 2 +- src/include/migraphx/shape.hpp | 2 +- src/shape.cpp | 2 +- test/fuse_pointwise.cpp | 2 +- test/fuse_reduce.cpp | 2 +- test/include/reduce.hpp | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 722597a63b2..563c420de82 
100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index b8d8dc9833f..f15eeabeea8 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 03f16ca7d01..9d16879cf35 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index d8c9c5ffed2..64b510176c3 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/shape.cpp b/src/shape.cpp index c7a96ed9efd..6698ff126fc 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 62751b1d831..08a1950fc63 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 4f07a8be63a..589c3c07cb1 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp index 59d62b289fe..6583bb900eb 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 0061b0acbca8e0bdc8389f578712262795a843db Mon Sep 17 00:00:00 2001 From: Shiv Date: Thu, 15 Jan 2026 15:52:12 -0800 Subject: [PATCH 04/21] add dynamic code object op --- src/include/migraphx/module.hpp | 8 +- src/include/migraphx/op/pointwise.hpp | 7 +- src/module.cpp | 72 ++++++- src/targets/gpu/compile_ops.cpp | 37 +--- .../include/migraphx/gpu/precompile_ops.hpp | 193 ++++++++++++++++++ src/targets/gpu/lowering.cpp | 26 ++- test/gpu/dynamic_code_object_op.cpp | 73 +++++++ 7 files changed, 373 insertions(+), 43 deletions(-) create mode 100644 src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp create mode 100644 test/gpu/dynamic_code_object_op.cpp diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index d68b2683e65..04b2dea9221 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -323,6 +323,12 @@ struct MIGRAPHX_EXPORT module void annotate(std::ostream& os, std::function a) const; std::vector get_sub_modules(bool shallow = false) const; + + /* Creates a new module with the same instructions but with different input parameter shapes. + Returns the new module by value without modifying the original. 
+ */ + module with_static_shapes(const std::vector& input_shapes); + /* sorts the module in topological order aka reverse-post order (RPO) DFS order it takes last instruction or @return as the root and walks back the graph and moves inputs of the each instruction such that it appears before the instruction itself. diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp index 51ca75ee92b..03f16ca7d01 100644 --- a/src/include/migraphx/op/pointwise.hpp +++ b/src/include/migraphx/op/pointwise.hpp @@ -61,7 +61,10 @@ struct pointwise MIGRAPHX_THROW("pointwise should have at least one input"); auto* pm = mods.front(); auto pnames = pm->get_parameter_names(); - check_shapes{inputs, *this}.has(pnames.size()).same_dims(); + check_shapes{inputs, *this, true}.has(pnames.size()).same_dims(); + + std::vector scalar_const_out_lens = + inputs.front().dynamic() ? std::vector{} : inputs.front().lens(); const auto rank = inputs.front().ndim(); const bool has_broadcasts = @@ -69,7 +72,7 @@ struct pointwise auto result = pm->compute_shapes( (rank > 1 and has_broadcasts) ? remove_broadcasts(inputs) : inputs, - {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()}); + {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens}); if(result.size() == 1) return result.front(); return shape{result}; diff --git a/src/module.cpp b/src/module.cpp index 4838d241904..7ead68fa272 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -717,7 +718,7 @@ std::vector module::compute_shapes(const std::vector& inputs, ins->get_shape().type_string() + " but passed " + ins_shapes[ins].type_string()); } - if(options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) + if(not ins->get_shape().dynamic() and options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) { MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" + to_string_range(ins->get_shape().lens()) + "} but passed {" + @@ -1466,6 +1467,73 @@ std::vector module::get_sub_modules(bool shallow) const return vec_modules; } +module module::with_static_shapes(const std::vector& input_shapes) +{ + // This routine creates a new module with the same instructions but with different input shapes. + // The sequence of instructions (operators and interconnectivity) is copied, but all input parameter shapes are replaced with new "input_shapes". + + // ensure input_shapes is the same length as the parameters. + auto param_names = this->get_parameter_names(); + assert(param_names.size() == input_shapes.size() && "Mismatch between input_shapes and parameter count"); + + // Make a mapping from the parameter names to the new shapes. 
+ std::unordered_map shape_map; + for(std::size_t i = 0; i < param_names.size(); ++i) + shape_map[param_names[i]] = input_shapes[i]; + + module new_mod; + + std::unordered_map ins_map; + + // First, create parameters with new shapes in new_mod and fill ins_map for params + for(auto ins : iterator_for(*this)) + { + if(ins->name() == "@param") + { + auto pname = any_cast(ins->get_operator()).parameter; + assert(shape_map.count(pname) > 0); + ins_map[ins] = new_mod.add_parameter(pname, shape_map.at(pname)); + } + } + + // Copy remaining instructions (except parameters) in order + for(auto ins : iterator_for(*this)) + { + if(ins->name() == "@param") + continue; + + // Gather new input refs for this instruction + std::vector new_args; + for(auto arg : ins->inputs()) + new_args.push_back(ins_map.at(arg)); + + // Gather new module argument refs if present + std::vector new_mod_args; + for(auto modarg : ins->module_inputs()) + new_mod_args.push_back(modarg); // Modules are *not* recreated, just reused + + instruction_ref new_ins; + if(ins->name() == "@literal") + { + new_ins = new_mod.add_literal(ins->get_literal()); + } + else if(ins->name() == "@return") + { + new_ins = new_mod.add_return(new_args); + } + else + { + if(new_mod_args.empty()) + new_ins = new_mod.add_instruction(ins->get_operator(), new_args); + else + new_ins = new_mod.add_instruction(ins->get_operator(), new_args, new_mod_args); + } + ins_map[ins] = new_ins; + } + + return new_mod; +} + module& module::sort() { if(this->begin() == this->end()) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 20ca81c3b1a..b984f5d4ee2 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -45,42 +46,8 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING); -struct 
precompile_op -{ - operation op = op::identity{}; - std::size_t additional_args = 1; - bool ignore_modules = false; - std::optional output_shape = nullopt; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.op, "op"), - f(self.additional_args, "additional_args"), - f(self.ignore_modules, "ignore_modules"), - f(self.output_shape, "output_shape")); - } - - std::string name() const { return "gpu::precompile_op"; } - - shape compute_shape(std::vector inputs, const std::vector& mods) const - { - // Pop off additional args - inputs.resize(inputs.size() - additional_args); - if(output_shape.has_value()) - return output_shape.value(); - if(ignore_modules) - return op.compute_shape(inputs); - return op.compute_shape(inputs, mods); - } - - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } -}; - MIGRAPHX_REGISTER_OP(precompile_op); +MIGRAPHX_REGISTER_OP(dynamic_code_object_op); struct compiled_result { diff --git a/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp b/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp new file mode 100644 index 00000000000..7bec916b425 --- /dev/null +++ b/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp @@ -0,0 +1,193 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#ifndef MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP +#define MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace gpu { + +struct precompile_op +{ + operation op = op::identity{}; + std::size_t additional_args = 1; + bool ignore_modules = false; + std::optional output_shape = nullopt; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op"), + f(self.additional_args, "additional_args"), + f(self.ignore_modules, "ignore_modules"), + f(self.output_shape, "output_shape")); + } + + std::string name() const { return "gpu::precompile_op"; } + + shape compute_shape(std::vector inputs, const std::vector& mods) const + { + // Pop off additional args + inputs.resize(inputs.size() - additional_args); + if(output_shape.has_value()) + return output_shape.value(); + if(ignore_modules) + return op.compute_shape(inputs); + return op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; + +struct dynamic_code_object_op +{ + operation pre_op = precompile_op{}; + std::optional output_shape = nullopt; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.pre_op, "pre_op"), f(self.output_shape, "output_shape")); + } + + std::string name() const { return "gpu::dynamic_code_object_op"; } + 
+ shape compute_shape(std::vector inputs, const std::vector& mods) const + { + return pre_op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + argument compute(context& ctx, + const shape&, + const std::vector& args, + const std::vector& module_args, + std::function( + module_ref&, const std::unordered_map&)> run) const + { + auto static_args = std::vector{args.begin(), args.end()}; + auto output_arg = static_args.back(); + module static_mod; + if(not module_args.empty()) + { + // rewrite module without dynamic shapes + auto mod_args = std::vector{args.begin(), args.end() - 1}; + static_mod = module_args.front()->with_static_shapes(to_shapes(mod_args)); + static_mod.set_bypass(true); + + // compute output arg shape + if(output_arg.get_shape().dynamic()) + { + auto out_shapes = static_mod.compute_shapes(to_shapes(mod_args)); + auto rsp_shape = (out_shapes.size() > 1) ? shape{out_shapes} : out_shapes.front(); + static_args[static_args.size() - 1] = output_arg.reshape(rsp_shape); + } + } + else + { + if(output_arg.get_shape().dynamic()) + { + auto out_shape = pre_op.compute_shape(to_shapes(static_args)); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + } + } + + auto temp_mod = module("temp_mod"); + std::vector args_ins; + std::vector idx(static_args.size()); + std::iota(std::begin(idx), std::end(idx), 0); + std::transform(static_args.begin(), + static_args.end(), + idx.begin(), + std::back_inserter(args_ins), + [&](const auto& arg, const auto& i) { + return temp_mod.add_parameter("temp_mod:x" + std::to_string(i), + arg.get_shape()); + }); + instruction_ref ins; + if(not module_args.empty()) + { + ins = temp_mod.add_instruction(pre_op, args_ins, {&static_mod}); + } + else + { + ins = temp_mod.add_instruction(pre_op, args_ins); + } + temp_mod.add_return({ins}); + + operation preop = any_cast(ins->get_operator()).op; + auto config = get_tuning_config(ctx, ins, 
preop, false); + value solution = value{}; + if(config.has_value()) + { + solution = config->solutions.front(); + } + auto compiled_op = compile(ctx, ins, preop, solution); + compiled_op.replace(temp_mod, ins); + run_passes(temp_mod, {dead_code_elimination{}}); + + // Finalize the module before execution + std::vector contexts = {migraphx::context(ctx)}; + temp_mod.finalize(contexts); + + // Build param_map based on ACTUAL parameters that exist + auto param_map = std::unordered_map{}; + for(auto i : idx) + { + param_map["temp_mod:x" + std::to_string(i)] = static_args[i]; + } + module_ref temp_mod_ref = &temp_mod; + + auto results = run(temp_mod_ref, param_map); + + if(results.size() > 1) + return results; + return results.front(); + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 0e24c3cc936..9f4d2b43790 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -182,6 +182,7 @@ struct miopen_apply else if(has_compiler_for(it->name())) { check_shape(s, insert_precompile_op(it)); + check_shape(s, insert_dynamic_code_object_op(it)); } else if(attrs.contains("target")) { @@ -240,6 +241,20 @@ struct miopen_apply ins->module_inputs()); } + instruction_ref insert_dynamic_code_object_op(instruction_ref ins) const + { + assert(ins->get_operator().name() == "gpu::precompile_op"); + + if(not ins->get_shape().dynamic()) + return ins; + + return mod->replace_instruction( + ins, + make_op("gpu::dynamic_code_object_op", {{"pre_op", to_value(ins->get_operator())}}), + ins->inputs(), + ins->module_inputs()); + } + instruction_ref insert_allocation(instruction_ref ins, const shape& s) const { return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); @@ -334,7 +349,8 @@ struct miopen_apply static bool use_miopen_pooling(instruction_ref ins) { 
if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or - not contains({shape::float_type, shape::half_type}, ins->get_shape().type())) + not contains({shape::float_type, shape::half_type}, ins->get_shape().type()) or + ins->get_shape().dynamic()) return false; auto&& op = ins->get_operator(); auto op_val = op.to_value(); @@ -355,7 +371,10 @@ struct miopen_apply { apply_map.emplace("pooling", [=](instruction_ref ins) { if(not use_miopen_pooling(ins)) - return insert_precompile_op(ins); + { + auto preop = insert_precompile_op(ins); + return insert_dynamic_code_object_op(preop); + } #if MIGRAPHX_USE_MIOPEN auto output = insert_allocation(ins, ins->get_shape()); std::vector refs = ins->inputs(); @@ -363,7 +382,8 @@ struct miopen_apply refs.push_back(output); return mod->replace_instruction(ins, make_op("gpu::pooling", op.to_value()), refs); #else - return insert_precompile_op(ins); + auto preop = insert_precompile_op(ins); + return insert_dynamic_code_object_op(preop); #endif }); } diff --git a/test/gpu/dynamic_code_object_op.cpp b/test/gpu/dynamic_code_object_op.cpp new file mode 100644 index 00000000000..ec22ded6dae --- /dev/null +++ b/test/gpu/dynamic_code_object_op.cpp @@ -0,0 +1,73 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void run_lowering(migraphx::program& p, bool offload_copy = false) +{ + auto ctx = migraphx::gpu::context{}; + migraphx::run_passes(*p.get_main_module(), {migraphx::gpu::lowering{&ctx, offload_copy}}); +} + +TEST_CASE(dynamic_code_object_op) +{ + migraphx::shape s{migraphx::shape::float_type, {{1, 3}, {2, 4}, {6, 6}}}; + migraphx::program p1; + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + + auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add")); + auto pw_name = pw->name(); + auto pw_module_inputs = pw->module_inputs(); + + mm->add_return({pw}); + + run_lowering(p1); + + bool found = false; + for(auto ins : iterator_for(*p1.get_main_module())) + { + if(ins->name() == "gpu::dynamic_code_object_op") + { + found = true; + auto dyn_op = migraphx::any_cast(ins->get_operator()); + auto pre_op = migraphx::any_cast(dyn_op.pre_op); + EXPECT(pre_op.op.name() == pw_name); + EXPECT(ins->module_inputs() == pw_module_inputs); + } + } + EXPECT(found); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 67119d8e80d1cac2a038dba702aedac56465989a Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 16 Jan 2026 09:46:36 -0800 Subject: [PATCH 05/21] add test for module helper function --- src/module.cpp | 20 +++++++++++++------- test/module_test.cpp | 20 ++++++++++++++++++++ 2 
files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/module.cpp b/src/module.cpp index 7ead68fa272..fa98636771c 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -718,7 +718,8 @@ std::vector module::compute_shapes(const std::vector& inputs, ins->get_shape().type_string() + " but passed " + ins_shapes[ins].type_string()); } - if(not ins->get_shape().dynamic() and options.strict_lens and ins->get_shape().lens() != ins_shapes[ins].lens()) + if(not ins->get_shape().dynamic() and options.strict_lens and + ins->get_shape().lens() != ins_shapes[ins].lens()) { MIGRAPHX_THROW(options.name + ": Mismatched lens: expected {" + to_string_range(ins->get_shape().lens()) + "} but passed {" + @@ -1470,11 +1471,12 @@ std::vector module::get_sub_modules(bool shallow) const module module::with_static_shapes(const std::vector& input_shapes) { // This routine creates a new module with the same instructions but with different input shapes. - // The sequence of instructions (operators and interconnectivity) is copied, but all input parameter shapes are replaced with new "input_shapes". + // The sequence of instructions (operators and interconnectivity) is copied, but all input + // parameter shapes are replaced with new "input_shapes". // ensure input_shapes is the same length as the parameters. auto param_names = this->get_parameter_names(); - assert(param_names.size() == input_shapes.size() && "Mismatch between input_shapes and parameter count"); + assert(param_names.size() == input_shapes.size()); // Make a mapping from the parameter names to the new shapes. 
std::unordered_map shape_map; @@ -1504,13 +1506,17 @@ module module::with_static_shapes(const std::vector& input_shapes) // Gather new input refs for this instruction std::vector new_args; - for(auto arg : ins->inputs()) - new_args.push_back(ins_map.at(arg)); + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(new_args), + [&](auto arg) { return ins_map.at(arg); }); // Gather new module argument refs if present std::vector new_mod_args; - for(auto modarg : ins->module_inputs()) - new_mod_args.push_back(modarg); // Modules are *not* recreated, just reused + std::transform(ins->module_inputs().begin(), + ins->module_inputs().end(), + std::back_inserter(new_mod_args), + [&](auto modarg) { return modarg; }); instruction_ref new_ins; if(ins->name() == "@literal") diff --git a/test/module_test.cpp b/test/module_test.cpp index 87ab9019e13..65b896db5fa 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -803,6 +803,26 @@ TEST_CASE(add_params) EXPECT(m1.get_parameter("x1") == map_ins[add]); } +TEST_CASE(with_static_shapes) +{ + auto create_module = [](const std::vector& input_shapes) { + migraphx::module m; + auto x = m.add_parameter("x", input_shapes[0]); + auto y = m.add_parameter("y", input_shapes[1]); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + auto reduce_mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), add); + m.add_return({reduce_mean}); + return m; + }; + auto dyn_shape = migraphx::shape{migraphx::shape::float_type, {{1,4}, {4,8}}}; + auto dyn_mod = create_module({dyn_shape, dyn_shape}); + + auto static_shape = migraphx::shape{migraphx::shape::float_type, {2, 5}}; + auto static_mod = create_module({static_shape, static_shape}); + + EXPECT(dyn_mod.with_static_shapes({static_shape, static_shape}).sort() == static_mod.sort()); +} + TEST_CASE(linear_graph_sort) { // From 46b7adccb74d50c1fc0c227de63983a2aee5e343 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 19 Jan 2026 11:19:23 
-0800 Subject: [PATCH 06/21] fix bug in pw_broadcast_pw and add test case --- src/fuse_pointwise.cpp | 19 +++++++++++++------ test/fuse_pointwise.cpp | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 563c420de82..865193d8ebc 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -348,23 +348,30 @@ struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes auto broadcast_pointwise = match::name("multibroadcast")(match::used_once(), match::args(pointwise)) .bind("broadcast"); - auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(), - match::nargs(2), - match::arg(1)(pointwise)) - .bind("broadcast"); + auto dyn_broadcast_pointwise = + match::name("multibroadcast")(match::used_once(), + match::nargs(2), + match::arg(0)(pointwise), + match::arg(1)(match::any().bind("ref_ins"))) + .bind("broadcast"); return match::name("pointwise")(match::any_of[match::inputs()]( match::any_of(broadcast_pointwise, dyn_broadcast_pointwise))); } void apply(module& m, const match::matcher_result& r) const { - auto broadcast_ins = r.instructions["broadcast"]; - auto x_ins = r.instructions["x"]; + auto broadcast_ins = r.instructions["broadcast"]; + auto x_ins = r.instructions["x"]; + bool is_dyn_broadcast = contains(r.instructions, "ref_ins"); auto broadcast = broadcast_ins->get_operator(); auto x_inputs = x_ins->inputs(); std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), [&](auto input) { + if(is_dyn_broadcast) + { + return m.insert_instruction(broadcast_ins, broadcast, {input, r.instructions["ref_ins"]}); + } return m.insert_instruction(broadcast_ins, broadcast, input); }); diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 08a1950fc63..54e762ae18f 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -1108,6 +1108,40 @@ TEST_CASE(add_broadcast_add) EXPECT(p1.sort() == p2.sort()); } 
+TEST_CASE(add_broadcast_add_dyn) +{ + migraphx::shape s1{migraphx::shape::float_type, {{2, 4}, {1, 1}}}; + migraphx::shape s2{migraphx::shape::float_type, {{2, 4}, {3, 3}}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto badd1 = mm->add_instruction(migraphx::make_op("multibroadcast"), add1, z); + auto add2 = mm->add_instruction(migraphx::make_op("add"), badd1, z); + mm->add_return({add2}); + } + run_pass(p1, {.enable_rewrite_broadcasts = true}); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto bx = mm->add_instruction(migraphx::make_op("multibroadcast"), x, z); + auto by = mm->add_instruction(migraphx::make_op("multibroadcast"), y, z); + auto fadd = + add_pointwise(p2, "main:pointwise0", {bx, by, z}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + mm->add_return({fadd}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(rewrite_broadcast_multi_output) { migraphx::shape s{migraphx::shape::float_type, {2, 3}}; From 7689b4c1dc736d4b9c102fc92645e74bdcc67d0a Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 19 Jan 2026 11:42:51 -0800 Subject: [PATCH 07/21] fix fuse_reduce test --- test/fuse_reduce.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 589c3c07cb1..de6e3056b08 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -65,7 +65,7 @@ TEST_CASE(single) TEST_CASE(single_dyn) { - migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::shape s{migraphx::shape::float_type, {1, 
4}, {3, 8}, {}}; migraphx::program p1; { auto* mm = p1.get_main_module(); From 366f7c2b4d2e18526271c26c416c3f0a2e211bd3 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 19 Jan 2026 11:47:13 -0800 Subject: [PATCH 08/21] fix dyn shape constructor in tests --- test/fuse_pointwise.cpp | 14 +++++++------- test/fuse_reduce.cpp | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 54e762ae18f..039222907f0 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -69,7 +69,7 @@ TEST_CASE(single) TEST_CASE(single_dyn) { - migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 3}, {}}; + migraphx::shape s{migraphx::shape::float_type, {{1, 3}, {4, 8}}}; migraphx::program p1; { auto* mm = p1.get_main_module(); @@ -972,14 +972,14 @@ TEST_CASE(add_reshape_add_nonstandard) migraphx::shape s3{migraphx::shape::float_type, {3, 10, 4, 2, 2}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s1); - auto z = mm->add_parameter("z", s2); - auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), add1); - auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); + auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); mm->add_return({add2}); } run_pass(p1); diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index de6e3056b08..53149d914b2 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -65,7 +65,7 @@ TEST_CASE(single) TEST_CASE(single_dyn) { - migraphx::shape s{migraphx::shape::float_type, {1, 4}, {3, 8}, {}}; + migraphx::shape s{migraphx::shape::float_type, {{1, 
3}, {4, 8}}}; migraphx::program p1; { auto* mm = p1.get_main_module(); From a821a235d54b1ec3a651ea550484711ac5d67202 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 19 Jan 2026 11:52:17 -0800 Subject: [PATCH 09/21] cleanup add_reduce_module func for tests --- test/include/reduce.hpp | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/include/reduce.hpp b/test/include/reduce.hpp index 6583bb900eb..2590b4caa2b 100644 --- a/test/include/reduce.hpp +++ b/test/include/reduce.hpp @@ -61,15 +61,7 @@ migraphx::module_ref add_reduce_module(migraphx::program& p, rm->set_bypass(); std::vector params; std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) { - migraphx::shape s; - if(input->get_shape().dynamic()) - { - s = input->get_shape(); - } - else - { - s = migraphx::shape{input->get_shape().type(), input->get_shape().lens()}; - } + auto s = input->get_shape().as_standard(); return rm->add_parameter("x" + std::to_string(params.size()), s); }); auto r = f(rm, params, axes); From 86d09f4b29362b2e3d0136454f7daa07dd69ed0a Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 20 Jan 2026 09:35:29 -0800 Subject: [PATCH 10/21] format --- src/fuse_pointwise.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp index 865193d8ebc..f0382b58d05 100644 --- a/src/fuse_pointwise.cpp +++ b/src/fuse_pointwise.cpp @@ -370,7 +370,8 @@ struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes std::transform(x_inputs.begin(), x_inputs.end(), x_inputs.begin(), [&](auto input) { if(is_dyn_broadcast) { - return m.insert_instruction(broadcast_ins, broadcast, {input, r.instructions["ref_ins"]}); + return m.insert_instruction( + broadcast_ins, broadcast, {input, r.instructions["ref_ins"]}); } return m.insert_instruction(broadcast_ins, broadcast, input); }); From 6ee3e5c0877e7330e34460e1b41c011d97a73180 Mon Sep 17 00:00:00 2001 From: Shiv Date: Fri, 23 Jan 2026 
12:42:44 -0800 Subject: [PATCH 11/21] add verify test --- test/verify/run_verify.cpp | 7 ++-- test/verify/test_dynamic_pointwise.cpp | 56 ++++++++++++++++++++++++++ test/verify/verify_program.hpp | 3 ++ 3 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 test/verify/test_dynamic_pointwise.cpp diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index 6a26a46d810..a3bf0026dbe 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -205,9 +205,10 @@ void run_verify::verify(const program_info& pi) const { if(x.second.dynamic()) { - // create static shape using maximum dimensions - migraphx::shape static_shape{x.second.type(), x.second.max_lens()}; - m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first)); + auto static_shape = contains(pi.test_dims, x.first) + ? pi.test_dims.at(x.first) + : migraphx::shape{x.second.type(), x.second.max_lens()}; + m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first)); } else { diff --git a/test/verify/test_dynamic_pointwise.cpp b/test/verify/test_dynamic_pointwise.cpp new file mode 100644 index 00000000000..9d3211ab520 --- /dev/null +++ b/test/verify/test_dynamic_pointwise.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +template +struct test_dynamic_pointwise : verify_program> +{ + migraphx::program create_program() const + { + migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {8, 16}, {4, 24}}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto relu = mm->add_instruction(migraphx::make_op("relu"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), relu, y); + mm->add_return({add}); + return p; + }; + + std::unordered_map get_test_dims() const + { + return {{"x", migraphx::shape{migraphx::shape::float_type, {test_dims...}}}, + {"y", migraphx::shape{migraphx::shape::float_type, {test_dims...}}}}; + } +}; + +template struct test_dynamic_pointwise<4, 16, 24>; // maxes +template struct test_dynamic_pointwise<2, 8, 4>; // mins +template struct test_dynamic_pointwise<3, 10, 13>; // inbetween diff --git a/test/verify/verify_program.hpp b/test/verify/verify_program.hpp index 39d2392bec8..7a303d9e741 100644 --- a/test/verify/verify_program.hpp +++ b/test/verify/verify_program.hpp @@ -38,6 +38,7 @@ struct program_info std::size_t tolerance; std::function get_program; migraphx::compile_options compile_options; + std::unordered_map test_dims; }; void register_program_info(const program_info& pi); @@ -66,6 +67,7 @@ struct register_verify_program_action pi.tolerance = x.get_tolerance(); pi.get_program = [x] { return 
x.create_program(); }; pi.compile_options = x.get_compile_options(); + pi.test_dims = x.get_test_dims(); register_program_info(pi); } }; @@ -79,6 +81,7 @@ struct verify_program : auto_register_verify_program std::string section() const { return "general"; }; migraphx::compile_options get_compile_options() const { return migraphx::compile_options{}; }; std::size_t get_tolerance() const { return 80; }; + std::unordered_map get_test_dims() const { return {}; }; }; #endif From a5275d7976c253d44d4159202d5e0bfac82ee2ed Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Jan 2026 11:02:15 -0800 Subject: [PATCH 12/21] format --- src/targets/gpu/lowering.cpp | 2 +- test/gpu/dynamic_code_object_op.cpp | 11 ++++++----- test/module_test.cpp | 15 ++++++++------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 9f4d2b43790..6653cb685a9 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -383,7 +383,7 @@ struct miopen_apply return mod->replace_instruction(ins, make_op("gpu::pooling", op.to_value()), refs); #else auto preop = insert_precompile_op(ins); - return insert_dynamic_op(preop); + return insert_dynamic_code_object_op(preop); #endif }); } diff --git a/test/gpu/dynamic_code_object_op.cpp b/test/gpu/dynamic_code_object_op.cpp index ec22ded6dae..0a4a52dbcfb 100644 --- a/test/gpu/dynamic_code_object_op.cpp +++ b/test/gpu/dynamic_code_object_op.cpp @@ -47,12 +47,12 @@ TEST_CASE(dynamic_code_object_op) auto a = mm->add_parameter("a", s); auto b = mm->add_parameter("b", s); - auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add")); - auto pw_name = pw->name(); + auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add")); + auto pw_name = pw->name(); auto pw_module_inputs = pw->module_inputs(); - + mm->add_return({pw}); - + run_lowering(p1); bool found = false; @@ -61,7 +61,8 @@ TEST_CASE(dynamic_code_object_op) if(ins->name() == 
"gpu::dynamic_code_object_op") { found = true; - auto dyn_op = migraphx::any_cast(ins->get_operator()); + auto dyn_op = + migraphx::any_cast(ins->get_operator()); auto pre_op = migraphx::any_cast(dyn_op.pre_op); EXPECT(pre_op.op.name() == pw_name); EXPECT(ins->module_inputs() == pw_module_inputs); diff --git a/test/module_test.cpp b/test/module_test.cpp index 65b896db5fa..3cdb0d6ea57 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -807,19 +807,20 @@ TEST_CASE(with_static_shapes) { auto create_module = [](const std::vector& input_shapes) { migraphx::module m; - auto x = m.add_parameter("x", input_shapes[0]); - auto y = m.add_parameter("y", input_shapes[1]); + auto x = m.add_parameter("x", input_shapes[0]); + auto y = m.add_parameter("y", input_shapes[1]); auto add = m.add_instruction(migraphx::make_op("add"), x, y); - auto reduce_mean = m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), add); + auto reduce_mean = + m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), add); m.add_return({reduce_mean}); return m; }; - auto dyn_shape = migraphx::shape{migraphx::shape::float_type, {{1,4}, {4,8}}}; - auto dyn_mod = create_module({dyn_shape, dyn_shape}); + auto dyn_shape = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 8}}}; + auto dyn_mod = create_module({dyn_shape, dyn_shape}); auto static_shape = migraphx::shape{migraphx::shape::float_type, {2, 5}}; - auto static_mod = create_module({static_shape, static_shape}); - + auto static_mod = create_module({static_shape, static_shape}); + EXPECT(dyn_mod.with_static_shapes({static_shape, static_shape}).sort() == static_mod.sort()); } From 1696bf904bb68d84ca2bffb85336330521a78235 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Jan 2026 12:03:57 -0800 Subject: [PATCH 13/21] tidy --- src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp 
b/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp index 7bec916b425..d9a7a214dab 100644 --- a/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp +++ b/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp @@ -90,7 +90,7 @@ struct dynamic_code_object_op std::string name() const { return "gpu::dynamic_code_object_op"; } - shape compute_shape(std::vector inputs, const std::vector& mods) const + shape compute_shape(const std::vector& inputs, const std::vector& mods) const { return pre_op.compute_shape(inputs, mods); } @@ -103,8 +103,8 @@ struct dynamic_code_object_op const shape&, const std::vector& args, const std::vector& module_args, - std::function( - module_ref&, const std::unordered_map&)> run) const + const std::function( + module_ref&, const std::unordered_map&)>& run) const { auto static_args = std::vector{args.begin(), args.end()}; auto output_arg = static_args.back(); From eed8c4210cf348a6eca978336603c53d41d08795 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Jan 2026 12:06:20 -0800 Subject: [PATCH 14/21] license --- src/targets/gpu/lowering.cpp | 2 +- test/module_test.cpp | 2 +- test/verify/run_verify.cpp | 2 +- test/verify/verify_program.hpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 6653cb685a9..8f04e95c40b 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/module_test.cpp b/test/module_test.cpp index 3cdb0d6ea57..0c70e8c814d 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index a3bf0026dbe..6197af4c6c8 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/verify_program.hpp b/test/verify/verify_program.hpp index 7a303d9e741..20b42e1cfef 100644 --- a/test/verify/verify_program.hpp +++ b/test/verify/verify_program.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 2e763fafaa8ca12d6b90e38e6032c1bfca76c2a3 Mon Sep 17 00:00:00 2001 From: Shiv Date: Mon, 26 Jan 2026 14:38:37 -0800 Subject: [PATCH 15/21] tidy + format --- src/targets/gpu/lowering.cpp | 2 +- test/verify/test_dynamic_pointwise.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 8f04e95c40b..a2fc6892e4c 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -381,7 +381,7 @@ struct miopen_apply auto&& op = ins->get_operator(); refs.push_back(output); return mod->replace_instruction(ins, make_op("gpu::pooling", op.to_value()), refs); -#else +#else auto preop = insert_precompile_op(ins); return insert_dynamic_code_object_op(preop); #endif diff --git a/test/verify/test_dynamic_pointwise.cpp b/test/verify/test_dynamic_pointwise.cpp index 9d3211ab520..7f744f4fc70 100644 --- a/test/verify/test_dynamic_pointwise.cpp +++ b/test/verify/test_dynamic_pointwise.cpp @@ -28,8 +28,8 @@ #include #include -template -struct test_dynamic_pointwise : verify_program> +template +struct test_dynamic_pointwise : verify_program> { migraphx::program create_program() const { @@ -46,8 +46,8 @@ struct test_dynamic_pointwise : verify_program get_test_dims() const { - return {{"x", migraphx::shape{migraphx::shape::float_type, {test_dims...}}}, - {"y", migraphx::shape{migraphx::shape::float_type, {test_dims...}}}}; + return {{"x", migraphx::shape{migraphx::shape::float_type, {TestDims...}}}, + {"y", migraphx::shape{migraphx::shape::float_type, {TestDims...}}}}; } }; From 3a98e069863003aa1a17889f3cd5e68353505497 Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 27 Jan 2026 10:14:29 -0800 Subject: [PATCH 16/21] disable dynamic verify tests for cpu --- test/verify/main.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 
deletion(-) diff --git a/test/verify/main.cpp b/test/verify/main.cpp index dde235e3094..ade41655780 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -142,7 +142,10 @@ int main(int argc, const char* argv[]) "test_bit_cast", "test_bit_cast", "test_bit_cast", - "test_bit_cast"}); + "test_bit_cast", + "test_dynamic_pointwise<4, 16, 24>", + "test_dynamic_pointwise<2, 8, 4>", + "test_dynamic_pointwise<3, 10, 13>"}); rv.disable_test_for("gpu", { // These passes on MI300 but fails on others, same issue as CPU. From 7a6e28b01e301e2c22fad6e5f32343998cd706ca Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 27 Jan 2026 15:50:40 -0800 Subject: [PATCH 17/21] review updates and cleanup --- src/include/migraphx/module.hpp | 2 +- src/module.cpp | 55 +---- src/targets/gpu/compile_ops.cpp | 154 +++++++++++++- .../include/migraphx/gpu/precompile_ops.hpp | 193 ------------------ test/gpu/dynamic_code_object_op.cpp | 6 - test/module_test.cpp | 4 +- test/verify/main.cpp | 2 +- 7 files changed, 165 insertions(+), 251 deletions(-) delete mode 100644 src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index 04b2dea9221..9741ad770e4 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -327,7 +327,7 @@ struct MIGRAPHX_EXPORT module /* Creates a new module with the same instructions but with different input parameter shapes. Returns the new module by value without modifying the original. 
*/ - module with_static_shapes(const std::vector& input_shapes); + module with_static_shapes(const std::unordered_map& input_shapes); /* sorts the module in topological order aka reverse-post order (RPO) DFS order it takes last instruction or @return as the root and walks back the graph and moves inputs diff --git a/src/module.cpp b/src/module.cpp index fa98636771c..6e72f0d8ffe 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -1468,7 +1468,7 @@ std::vector module::get_sub_modules(bool shallow) const return vec_modules; } -module module::with_static_shapes(const std::vector& input_shapes) +module module::with_static_shapes(const std::unordered_map& input_shapes) { // This routine creates a new module with the same instructions but with different input shapes. // The sequence of instructions (operators and interconnectivity) is copied, but all input @@ -1478,64 +1478,23 @@ module module::with_static_shapes(const std::vector& input_shapes) auto param_names = this->get_parameter_names(); assert(param_names.size() == input_shapes.size()); - // Make a mapping from the parameter names to the new shapes. 
- std::unordered_map shape_map; - for(std::size_t i = 0; i < param_names.size(); ++i) - shape_map[param_names[i]] = input_shapes[i]; - module new_mod; - std::unordered_map ins_map; - // First, create parameters with new shapes in new_mod and fill ins_map for params + // create parameters with new shapes in new_mod and fill ins_map for params for(auto ins : iterator_for(*this)) { if(ins->name() == "@param") { auto pname = any_cast(ins->get_operator()).parameter; - assert(shape_map.count(pname) > 0); - ins_map[ins] = new_mod.add_parameter(pname, shape_map.at(pname)); + assert(input_shapes.count(pname) > 0); + ins_map[ins] = new_mod.add_parameter(pname, input_shapes.at(pname)); } } - // Copy remaining instructions (except parameters) in order - for(auto ins : iterator_for(*this)) - { - if(ins->name() == "@param") - continue; - - // Gather new input refs for this instruction - std::vector new_args; - std::transform(ins->inputs().begin(), - ins->inputs().end(), - std::back_inserter(new_args), - [&](auto arg) { return ins_map.at(arg); }); - - // Gather new module argument refs if present - std::vector new_mod_args; - std::transform(ins->module_inputs().begin(), - ins->module_inputs().end(), - std::back_inserter(new_mod_args), - [&](auto modarg) { return modarg; }); - - instruction_ref new_ins; - if(ins->name() == "@literal") - { - new_ins = new_mod.add_literal(ins->get_literal()); - } - else if(ins->name() == "@return") - { - new_ins = new_mod.add_return(new_args); - } - else - { - if(new_mod_args.empty()) - new_ins = new_mod.add_instruction(ins->get_operator(), new_args); - else - new_ins = new_mod.add_instruction(ins->get_operator(), new_args, new_mod_args); - } - ins_map[ins] = new_ins; - } + // Copy remaining instructions in order + auto ret = new_mod.insert_instructions(new_mod.end(), this, &ins_map); + new_mod.add_return(ret); return new_mod; } diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index cc612b4770a..9edd3c2ef7a 100644 --- 
a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -37,7 +37,6 @@ #include #include #include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -46,7 +45,160 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING); +struct precompile_op +{ + operation op = op::identity{}; + std::size_t additional_args = 1; + bool ignore_modules = false; + std::optional output_shape = nullopt; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op"), + f(self.additional_args, "additional_args"), + f(self.ignore_modules, "ignore_modules"), + f(self.output_shape, "output_shape")); + } + + std::string name() const { return "gpu::precompile_op"; } + + shape compute_shape(std::vector inputs, const std::vector& mods) const + { + // Pop off additional args + inputs.resize(inputs.size() - additional_args); + if(output_shape.has_value()) + return output_shape.value(); + if(ignore_modules) + return op.compute_shape(inputs); + return op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } +}; MIGRAPHX_REGISTER_OP(precompile_op); + +struct dynamic_code_object_op +{ + operation pre_op = precompile_op{}; + std::optional output_shape = nullopt; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.pre_op, "pre_op"), f(self.output_shape, "output_shape")); + } + + std::string name() const { return "gpu::dynamic_code_object_op"; } + + shape compute_shape(const std::vector& inputs, const std::vector& mods) const + { + return pre_op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + argument compute(context& ctx, + const shape&, + const std::vector& args, + const std::vector& module_args, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + 
auto static_args = std::vector{args.begin(), args.end()}; + auto output_arg = static_args.back(); + module static_mod; + if(not module_args.empty()) + { + // rewrite module without dynamic shapes + auto pnames = module_args.front()->get_parameter_names(); + std::unordered_map mod_arg_shapes; + std::transform(pnames.begin(), + pnames.end(), + args.begin(), + std::inserter(mod_arg_shapes, mod_arg_shapes.end()), + [&](const auto& name, const auto& arg) { + return std::make_pair(name, arg.get_shape()); + }); + static_mod = module_args.front()->with_static_shapes(mod_arg_shapes); + static_mod.set_bypass(true); + + // compute output arg shape + if(output_arg.get_shape().dynamic()) + { + auto mod_args = std::vector{args.begin(), args.end() - 1}; + auto out_shapes = static_mod.compute_shapes(to_shapes(mod_args)); + auto rsp_shape = (out_shapes.size() > 1) ? shape{out_shapes} : out_shapes.front(); + static_args[static_args.size() - 1] = output_arg.reshape(rsp_shape); + } + } + else + { + if(output_arg.get_shape().dynamic()) + { + auto out_shape = pre_op.compute_shape(to_shapes(static_args)); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + } + } + + auto temp_mod = module("temp_mod"); + std::vector args_ins; + std::vector idx(static_args.size()); + std::iota(std::begin(idx), std::end(idx), 0); + std::transform(static_args.begin(), + static_args.end(), + idx.begin(), + std::back_inserter(args_ins), + [&](const auto& arg, const auto& i) { + return temp_mod.add_parameter("temp_mod:x" + std::to_string(i), + arg.get_shape()); + }); + instruction_ref ins; + if(not module_args.empty()) + { + ins = temp_mod.add_instruction(pre_op, args_ins, {&static_mod}); + } + else + { + ins = temp_mod.add_instruction(pre_op, args_ins); + } + temp_mod.add_return({ins}); + + operation preop = any_cast(ins->get_operator()).op; + auto config = get_tuning_config(ctx, ins, preop, false); + value solution = value{}; + if(config.has_value()) + { + solution = 
config->solutions.front(); + } + auto compiled_op = compile(ctx, ins, preop, solution); + compiled_op.replace(temp_mod, ins); + run_passes(temp_mod, {dead_code_elimination{}}); + + // Finalize the module before execution + std::vector contexts = {migraphx::context(ctx)}; + temp_mod.finalize(contexts); + + // Build param_map based on ACTUAL parameters that exist + auto param_map = std::unordered_map{}; + for(auto i : idx) + { + param_map["temp_mod:x" + std::to_string(i)] = static_args[i]; + } + module_ref temp_mod_ref = &temp_mod; + + auto results = run(temp_mod_ref, param_map); + + if(results.size() > 1) + return results; + return results.front(); + } +}; MIGRAPHX_REGISTER_OP(dynamic_code_object_op); struct compiled_result diff --git a/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp b/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp deleted file mode 100644 index d9a7a214dab..00000000000 --- a/src/targets/gpu/include/migraphx/gpu/precompile_ops.hpp +++ /dev/null @@ -1,193 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#ifndef MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP -#define MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { -namespace gpu { - -struct precompile_op -{ - operation op = op::identity{}; - std::size_t additional_args = 1; - bool ignore_modules = false; - std::optional output_shape = nullopt; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.op, "op"), - f(self.additional_args, "additional_args"), - f(self.ignore_modules, "ignore_modules"), - f(self.output_shape, "output_shape")); - } - - std::string name() const { return "gpu::precompile_op"; } - - shape compute_shape(std::vector inputs, const std::vector& mods) const - { - // Pop off additional args - inputs.resize(inputs.size() - additional_args); - if(output_shape.has_value()) - return output_shape.value(); - if(ignore_modules) - return op.compute_shape(inputs); - return op.compute_shape(inputs, mods); - } - - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return shapes.size() - 1; - } -}; - -struct dynamic_code_object_op -{ - operation pre_op = precompile_op{}; - std::optional output_shape = nullopt; - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.pre_op, "pre_op"), f(self.output_shape, "output_shape")); - } - - std::string name() const { return "gpu::dynamic_code_object_op"; } - - shape compute_shape(const std::vector& inputs, const std::vector& mods) const - { - return pre_op.compute_shape(inputs, mods); - } - - std::ptrdiff_t output_alias(const std::vector& shapes) const - { - return 
shapes.size() - 1; - } - argument compute(context& ctx, - const shape&, - const std::vector& args, - const std::vector& module_args, - const std::function( - module_ref&, const std::unordered_map&)>& run) const - { - auto static_args = std::vector{args.begin(), args.end()}; - auto output_arg = static_args.back(); - module static_mod; - if(not module_args.empty()) - { - // rewrite module without dynamic shapes - auto mod_args = std::vector{args.begin(), args.end() - 1}; - static_mod = module_args.front()->with_static_shapes(to_shapes(mod_args)); - static_mod.set_bypass(true); - - // compute output arg shape - if(output_arg.get_shape().dynamic()) - { - auto out_shapes = static_mod.compute_shapes(to_shapes(mod_args)); - auto rsp_shape = (out_shapes.size() > 1) ? shape{out_shapes} : out_shapes.front(); - static_args[static_args.size() - 1] = output_arg.reshape(rsp_shape); - } - } - else - { - if(output_arg.get_shape().dynamic()) - { - auto out_shape = pre_op.compute_shape(to_shapes(static_args)); - static_args[static_args.size() - 1] = output_arg.reshape(out_shape); - } - } - - auto temp_mod = module("temp_mod"); - std::vector args_ins; - std::vector idx(static_args.size()); - std::iota(std::begin(idx), std::end(idx), 0); - std::transform(static_args.begin(), - static_args.end(), - idx.begin(), - std::back_inserter(args_ins), - [&](const auto& arg, const auto& i) { - return temp_mod.add_parameter("temp_mod:x" + std::to_string(i), - arg.get_shape()); - }); - instruction_ref ins; - if(not module_args.empty()) - { - ins = temp_mod.add_instruction(pre_op, args_ins, {&static_mod}); - } - else - { - ins = temp_mod.add_instruction(pre_op, args_ins); - } - temp_mod.add_return({ins}); - - operation preop = any_cast(ins->get_operator()).op; - auto config = get_tuning_config(ctx, ins, preop, false); - value solution = value{}; - if(config.has_value()) - { - solution = config->solutions.front(); - } - auto compiled_op = compile(ctx, ins, preop, solution); - 
compiled_op.replace(temp_mod, ins); - run_passes(temp_mod, {dead_code_elimination{}}); - - // Finalize the module before execution - std::vector contexts = {migraphx::context(ctx)}; - temp_mod.finalize(contexts); - - // Build param_map based on ACTUAL parameters that exist - auto param_map = std::unordered_map{}; - for(auto i : idx) - { - param_map["temp_mod:x" + std::to_string(i)] = static_args[i]; - } - module_ref temp_mod_ref = &temp_mod; - - auto results = run(temp_mod_ref, param_map); - - if(results.size() > 1) - return results; - return results.front(); - } -}; - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx -#endif // MIGRAPHX_GUARD_GPU_PRECOMPILE_OPS_HPP diff --git a/test/gpu/dynamic_code_object_op.cpp b/test/gpu/dynamic_code_object_op.cpp index 0a4a52dbcfb..f9711b4d704 100644 --- a/test/gpu/dynamic_code_object_op.cpp +++ b/test/gpu/dynamic_code_object_op.cpp @@ -23,7 +23,6 @@ */ #include -#include #include #include #include @@ -48,7 +47,6 @@ TEST_CASE(dynamic_code_object_op) auto b = mm->add_parameter("b", s); auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add")); - auto pw_name = pw->name(); auto pw_module_inputs = pw->module_inputs(); mm->add_return({pw}); @@ -61,10 +59,6 @@ TEST_CASE(dynamic_code_object_op) if(ins->name() == "gpu::dynamic_code_object_op") { found = true; - auto dyn_op = - migraphx::any_cast(ins->get_operator()); - auto pre_op = migraphx::any_cast(dyn_op.pre_op); - EXPECT(pre_op.op.name() == pw_name); EXPECT(ins->module_inputs() == pw_module_inputs); } } diff --git a/test/module_test.cpp b/test/module_test.cpp index 0c70e8c814d..9ce615cf4bc 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -820,8 +820,10 @@ TEST_CASE(with_static_shapes) auto static_shape = migraphx::shape{migraphx::shape::float_type, {2, 5}}; auto static_mod = create_module({static_shape, static_shape}); + std::unordered_map static_arg_shapes{{"x", static_shape}, + {"y", static_shape}}; - 
EXPECT(dyn_mod.with_static_shapes({static_shape, static_shape}).sort() == static_mod.sort()); + EXPECT(dyn_mod.with_static_shapes(static_arg_shapes).sort() == static_mod.sort()); } TEST_CASE(linear_graph_sort) diff --git a/test/verify/main.cpp b/test/verify/main.cpp index ade41655780..f641546e432 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From c8c69a03b051a0804177f04445923ede42ad3cac Mon Sep 17 00:00:00 2001 From: Shiv Date: Tue, 27 Jan 2026 15:52:11 -0800 Subject: [PATCH 18/21] tidy --- test/verify/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/main.cpp b/test/verify/main.cpp index f641546e432..0db47d92b24 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -142,7 +142,7 @@ int main(int argc, const char* argv[]) "test_bit_cast", "test_bit_cast", "test_bit_cast", - "test_bit_cast" + "test_bit_cast", "test_dynamic_pointwise<4, 16, 24>", "test_dynamic_pointwise<2, 8, 4>", "test_dynamic_pointwise<3, 10, 13>"}); From 0fd36f6201a1eec54b18e08dad3fbb7253f48c80 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 28 Jan 2026 15:53:14 -0800 Subject: [PATCH 19/21] clean up code object op and address review comments --- src/targets/gpu/compile_ops.cpp | 101 +++++++++++++++++++------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 9edd3c2ef7a..b3d79d061ca 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -83,13 +83,15 @@ MIGRAPHX_REGISTER_OP(precompile_op); struct dynamic_code_object_op { - operation pre_op = precompile_op{}; - 
std::optional output_shape = nullopt; + operation pre_op = precompile_op{}; + mutable std::shared_ptr cache_mod; + mutable std::shared_ptr> cache_input_shapes; + mutable std::shared_ptr cache_static_output_shape; template static auto reflect(Self& self, F f) { - return pack(f(self.pre_op, "pre_op"), f(self.output_shape, "output_shape")); + return pack(f(self.pre_op, "pre_op")); } std::string name() const { return "gpu::dynamic_code_object_op"; } @@ -103,6 +105,19 @@ struct dynamic_code_object_op { return shapes.size() - 1; } + std::unordered_map build_param_map(const std::vector& args, + module_ref mod) const + { + auto pnames = mod->get_parameter_names(); + assert(pnames.size() == args.size()); + std::unordered_map param_map; + std::transform(pnames.begin(), + pnames.end(), + args.begin(), + std::inserter(param_map, param_map.end()), + [](const auto& name, const auto& arg) { return std::make_pair(name, arg); }); + return param_map; + } argument compute(context& ctx, const shape&, const std::vector& args, @@ -112,10 +127,28 @@ struct dynamic_code_object_op { auto static_args = std::vector{args.begin(), args.end()}; auto output_arg = static_args.back(); - module static_mod; + + if(cache_mod and *cache_input_shapes == to_shapes(args)) + { + static_args[static_args.size() - 1] = output_arg.reshape(*cache_static_output_shape); + auto mod = cache_mod.get(); + auto param_map = build_param_map(static_args, mod); + auto results = run(mod, param_map); + if(results.size() > 1) + return results; + return results.front(); + } + + if(output_arg.get_shape().dynamic()) + { + auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + } + + // Rewrite submodule without dynamic shapes to be used as the IR for compilation + module static_submod; if(not module_args.empty()) { - // rewrite module without dynamic shapes auto pnames = module_args.front()->get_parameter_names(); std::unordered_map 
mod_arg_shapes; std::transform(pnames.begin(), @@ -125,28 +158,13 @@ struct dynamic_code_object_op [&](const auto& name, const auto& arg) { return std::make_pair(name, arg.get_shape()); }); - static_mod = module_args.front()->with_static_shapes(mod_arg_shapes); - static_mod.set_bypass(true); - - // compute output arg shape - if(output_arg.get_shape().dynamic()) - { - auto mod_args = std::vector{args.begin(), args.end() - 1}; - auto out_shapes = static_mod.compute_shapes(to_shapes(mod_args)); - auto rsp_shape = (out_shapes.size() > 1) ? shape{out_shapes} : out_shapes.front(); - static_args[static_args.size() - 1] = output_arg.reshape(rsp_shape); - } - } - else - { - if(output_arg.get_shape().dynamic()) - { - auto out_shape = pre_op.compute_shape(to_shapes(static_args)); - static_args[static_args.size() - 1] = output_arg.reshape(out_shape); - } + static_submod = module_args.front()->with_static_shapes(mod_arg_shapes); + static_submod.set_bypass(true); } - auto temp_mod = module("temp_mod"); + // Create runtime module which will be compiled and cached + auto name = "runtime_mod:" + module_args.front()->name(); + auto runtime_mod = module(name); std::vector args_ins; std::vector idx(static_args.size()); std::iota(std::begin(idx), std::end(idx), 0); @@ -155,20 +173,21 @@ struct dynamic_code_object_op idx.begin(), std::back_inserter(args_ins), [&](const auto& arg, const auto& i) { - return temp_mod.add_parameter("temp_mod:x" + std::to_string(i), - arg.get_shape()); + return runtime_mod.add_parameter(name + ":x" + std::to_string(i), + arg.get_shape()); }); instruction_ref ins; if(not module_args.empty()) { - ins = temp_mod.add_instruction(pre_op, args_ins, {&static_mod}); + ins = runtime_mod.add_instruction(pre_op, args_ins, {&static_submod}); } else { - ins = temp_mod.add_instruction(pre_op, args_ins); + ins = runtime_mod.add_instruction(pre_op, args_ins); } - temp_mod.add_return({ins}); + runtime_mod.add_return({ins}); + // Compile ins and replace with a compiled code 
object op operation preop = any_cast(ins->get_operator()).op; auto config = get_tuning_config(ctx, ins, preop, false); value solution = value{}; @@ -177,22 +196,24 @@ struct dynamic_code_object_op solution = config->solutions.front(); } auto compiled_op = compile(ctx, ins, preop, solution); - compiled_op.replace(temp_mod, ins); - run_passes(temp_mod, {dead_code_elimination{}}); + compiled_op.replace(runtime_mod, ins); + run_passes(runtime_mod, {dead_code_elimination{}}); // Finalize the module before execution std::vector contexts = {migraphx::context(ctx)}; - temp_mod.finalize(contexts); + runtime_mod.finalize(contexts); + + // Update cache + // TODO: This will be updated to store compiled code objects for all encountered shapes + cache_mod = std::make_shared(runtime_mod); + cache_input_shapes = std::make_shared>(to_shapes(args)); + cache_static_output_shape = std::make_shared(static_args.back().get_shape()); // Build param_map based on ACTUAL parameters that exist - auto param_map = std::unordered_map{}; - for(auto i : idx) - { - param_map["temp_mod:x" + std::to_string(i)] = static_args[i]; - } - module_ref temp_mod_ref = &temp_mod; + module_ref runtime_mod_ref = &runtime_mod; + auto param_map = build_param_map(static_args, runtime_mod_ref); - auto results = run(temp_mod_ref, param_map); + auto results = run(runtime_mod_ref, param_map); if(results.size() > 1) return results; From f238dd4fc0ced2546e70209063a04442c509b8ae Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 28 Jan 2026 17:45:34 -0800 Subject: [PATCH 20/21] cppcheck fix --- src/targets/gpu/compile_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index b3d79d061ca..72ef45223cf 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -106,7 +106,7 @@ struct dynamic_code_object_op return shapes.size() - 1; } std::unordered_map build_param_map(const std::vector& args, - module_ref mod) const 
+ const_module_ref mod) const { auto pnames = mod->get_parameter_names(); assert(pnames.size() == args.size()); From 121d9d11e4177784f0dbc12362fd1bf9aeee2c17 Mon Sep 17 00:00:00 2001 From: Shiv Date: Wed, 28 Jan 2026 17:58:33 -0800 Subject: [PATCH 21/21] tidy --- src/targets/gpu/compile_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 72ef45223cf..3ff730455d1 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -131,7 +131,7 @@ struct dynamic_code_object_op if(cache_mod and *cache_input_shapes == to_shapes(args)) { static_args[static_args.size() - 1] = output_arg.reshape(*cache_static_output_shape); - auto mod = cache_mod.get(); + auto* mod = cache_mod.get(); auto param_map = build_param_map(static_args, mod); auto results = run(mod, param_map); if(results.size() > 1)