diff --git a/src/include/migraphx/module.hpp b/src/include/migraphx/module.hpp index d68b2683e65..9741ad770e4 100644 --- a/src/include/migraphx/module.hpp +++ b/src/include/migraphx/module.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -323,6 +323,12 @@ struct MIGRAPHX_EXPORT module void annotate(std::ostream& os, std::function a) const; std::vector get_sub_modules(bool shallow = false) const; + + /* Creates a new module with the same instructions but with different input parameter shapes. + Returns the new module by value without modifying the original. + */ + module with_static_shapes(const std::unordered_map& input_shapes); + /* sorts the module in topological order aka reverse-post order (RPO) DFS order it takes last instruction or @return as the root and walks back the graph and moves inputs of the each instruction such that it appears before the instruction itself. diff --git a/src/module.cpp b/src/module.cpp index 92f09db3eea..6e72f0d8ffe 100644 --- a/src/module.cpp +++ b/src/module.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -1467,6 +1468,37 @@ std::vector module::get_sub_modules(bool shallow) const return vec_modules; } +module module::with_static_shapes(const std::unordered_map& input_shapes) +{ + // This routine creates a new module with the same instructions but with different input shapes. + // The sequence of instructions (operators and interconnectivity) is copied, but all input + // parameter shapes are replaced with new "input_shapes". + + // ensure input_shapes is the same length as the parameters. + auto param_names = this->get_parameter_names(); + assert(param_names.size() == input_shapes.size()); + + module new_mod; + std::unordered_map ins_map; + + // create parameters with new shapes in new_mod and fill ins_map for params + for(auto ins : iterator_for(*this)) + { + if(ins->name() == "@param") + { + auto pname = any_cast(ins->get_operator()).parameter; + assert(input_shapes.count(pname) > 0); + ins_map[ins] = new_mod.add_parameter(pname, input_shapes.at(pname)); + } + } + + // Copy remaining instructions in order + auto ret = new_mod.insert_instructions(new_mod.end(), this, &ins_map); + new_mod.add_return(ret); + + return new_mod; +} + module& module::sort() { if(this->begin() == this->end()) diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index acf8eda65ee..3ff730455d1 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -79,9 +79,149 @@ struct precompile_op return shapes.size() - 1; } }; - MIGRAPHX_REGISTER_OP(precompile_op); +struct dynamic_code_object_op +{ + operation pre_op = precompile_op{}; + mutable std::shared_ptr cache_mod; + mutable std::shared_ptr> cache_input_shapes; + mutable std::shared_ptr cache_static_output_shape; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.pre_op, "pre_op")); + } + + std::string name() const { return "gpu::dynamic_code_object_op"; } + + shape compute_shape(const std::vector& inputs, const std::vector& mods) const + { + return pre_op.compute_shape(inputs, mods); + } + + std::ptrdiff_t output_alias(const std::vector& shapes) const + { + return shapes.size() - 1; + } + std::unordered_map build_param_map(const std::vector& args, + const_module_ref mod) const + { + auto pnames = mod->get_parameter_names(); + assert(pnames.size() == args.size()); + std::unordered_map param_map; + std::transform(pnames.begin(), + pnames.end(), + args.begin(), + std::inserter(param_map, param_map.end()), + [](const auto& name, const auto& arg) { return std::make_pair(name, arg); }); + return param_map; + } + argument compute(context& ctx, + const shape&, + const std::vector& args, + const std::vector& module_args, + const std::function( + module_ref&, const std::unordered_map&)>& run) const + { + auto static_args = std::vector{args.begin(), args.end()}; + auto output_arg = static_args.back(); + + if(cache_mod and *cache_input_shapes == to_shapes(args)) + { + static_args[static_args.size() - 1] = output_arg.reshape(*cache_static_output_shape); + auto* mod = cache_mod.get(); + auto param_map = build_param_map(static_args, mod); + auto results = run(mod, param_map); + if(results.size() > 1) + return results; + return results.front(); + } + + if(output_arg.get_shape().dynamic()) + { + auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + } + + // Rewrite submodule without dynamic shapes to be used as the IR for compilation + module static_submod; + if(not module_args.empty()) + { + auto pnames = module_args.front()->get_parameter_names(); + std::unordered_map mod_arg_shapes; + std::transform(pnames.begin(), + pnames.end(), + args.begin(), + std::inserter(mod_arg_shapes, mod_arg_shapes.end()), + [&](const auto& name, const auto& arg) { + return std::make_pair(name, arg.get_shape()); + }); + static_submod = module_args.front()->with_static_shapes(mod_arg_shapes); + static_submod.set_bypass(true); + } + + // Create runtime module which will be compiled and cached + auto name = "runtime_mod:" + module_args.front()->name(); + auto runtime_mod = module(name); + std::vector args_ins; + std::vector idx(static_args.size()); + std::iota(std::begin(idx), std::end(idx), 0); + std::transform(static_args.begin(), + static_args.end(), + idx.begin(), + std::back_inserter(args_ins), + [&](const auto& arg, const auto& i) { + return runtime_mod.add_parameter(name + ":x" + std::to_string(i), + arg.get_shape()); + }); + instruction_ref ins; + if(not module_args.empty()) + { + ins = runtime_mod.add_instruction(pre_op, args_ins, {&static_submod}); + } + else + { + ins = runtime_mod.add_instruction(pre_op, args_ins); + } + runtime_mod.add_return({ins}); + + // Compile ins and replace with a compiled code object op + operation preop = any_cast(ins->get_operator()).op; + auto config = get_tuning_config(ctx, ins, preop, false); + value solution = value{}; + if(config.has_value()) + { + solution = config->solutions.front(); + } + auto compiled_op = compile(ctx, ins, preop, solution); + compiled_op.replace(runtime_mod, ins); + run_passes(runtime_mod, {dead_code_elimination{}}); + + // Finalize the module before execution + std::vector contexts = {migraphx::context(ctx)}; + runtime_mod.finalize(contexts); + + // Update cache + // TODO: This will be updated to store compiled code objects for all encountered shapes + cache_mod = std::make_shared(runtime_mod); + cache_input_shapes = std::make_shared>(to_shapes(args)); + cache_static_output_shape = std::make_shared(static_args.back().get_shape()); + + // Build param_map based on ACTUAL parameters that exist + module_ref runtime_mod_ref = &runtime_mod; + auto param_map = build_param_map(static_args, runtime_mod_ref); + + auto results = run(runtime_mod_ref, param_map); + + if(results.size() > 1) + return results; + return results.front(); + } +}; +MIGRAPHX_REGISTER_OP(dynamic_code_object_op); + struct compiled_result { compiler_replace replace; diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 0e24c3cc936..a2fc6892e4c 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -182,6 +182,7 @@ struct miopen_apply else if(has_compiler_for(it->name())) { check_shape(s, insert_precompile_op(it)); + check_shape(s, insert_dynamic_code_object_op(it)); } else if(attrs.contains("target")) { @@ -240,6 +241,20 @@ struct miopen_apply ins->module_inputs()); } + instruction_ref insert_dynamic_code_object_op(instruction_ref ins) const + { + assert(ins->get_operator().name() == "gpu::precompile_op"); + + if(not ins->get_shape().dynamic()) + return ins; + + return mod->replace_instruction( + ins, + make_op("gpu::dynamic_code_object_op", {{"pre_op", to_value(ins->get_operator())}}), + ins->inputs(), + ins->module_inputs()); + } + instruction_ref insert_allocation(instruction_ref ins, const shape& s) const { return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}})); @@ -334,7 +349,8 @@ struct miopen_apply static bool use_miopen_pooling(instruction_ref ins) { if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or - not contains({shape::float_type, shape::half_type}, ins->get_shape().type())) + not contains({shape::float_type, shape::half_type}, ins->get_shape().type()) or + ins->get_shape().dynamic()) return false; auto&& op = ins->get_operator(); auto op_val = op.to_value(); @@ -355,15 +371,19 @@ struct miopen_apply { apply_map.emplace("pooling", [=](instruction_ref ins) { if(not use_miopen_pooling(ins)) - return insert_precompile_op(ins); + { + auto preop = insert_precompile_op(ins); + return insert_dynamic_code_object_op(preop); + } #if MIGRAPHX_USE_MIOPEN auto output = insert_allocation(ins, ins->get_shape()); std::vector refs = ins->inputs(); auto&& op = ins->get_operator(); refs.push_back(output); return mod->replace_instruction(ins, make_op("gpu::pooling", op.to_value()), refs); -#else - return insert_precompile_op(ins); +#else + auto preop = insert_precompile_op(ins); + return insert_dynamic_code_object_op(preop); #endif }); } diff --git a/test/gpu/dynamic_code_object_op.cpp b/test/gpu/dynamic_code_object_op.cpp new file mode 100644 index 00000000000..f9711b4d704 --- /dev/null +++ b/test/gpu/dynamic_code_object_op.cpp @@ -0,0 +1,68 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void run_lowering(migraphx::program& p, bool offload_copy = false) +{ + auto ctx = migraphx::gpu::context{}; + migraphx::run_passes(*p.get_main_module(), {migraphx::gpu::lowering{&ctx, offload_copy}}); +} + +TEST_CASE(dynamic_code_object_op) +{ + migraphx::shape s{migraphx::shape::float_type, {{1, 3}, {2, 4}, {6, 6}}}; + migraphx::program p1; + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("a", s); + auto b = mm->add_parameter("b", s); + + auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add")); + auto pw_module_inputs = pw->module_inputs(); + + mm->add_return({pw}); + + run_lowering(p1); + + bool found = false; + for(auto ins : iterator_for(*p1.get_main_module())) + { + if(ins->name() == "gpu::dynamic_code_object_op") + { + found = true; + EXPECT(ins->module_inputs() == pw_module_inputs); + } + } + EXPECT(found); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/module_test.cpp b/test/module_test.cpp index 87ab9019e13..9ce615cf4bc 100644 --- a/test/module_test.cpp +++ b/test/module_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -803,6 +803,29 @@ TEST_CASE(add_params) EXPECT(m1.get_parameter("x1") == map_ins[add]); } +TEST_CASE(with_static_shapes) +{ + auto create_module = [](const std::vector& input_shapes) { + migraphx::module m; + auto x = m.add_parameter("x", input_shapes[0]); + auto y = m.add_parameter("y", input_shapes[1]); + auto add = m.add_instruction(migraphx::make_op("add"), x, y); + auto reduce_mean = + m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), add); + m.add_return({reduce_mean}); + return m; + }; + auto dyn_shape = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 8}}}; + auto dyn_mod = create_module({dyn_shape, dyn_shape}); + + auto static_shape = migraphx::shape{migraphx::shape::float_type, {2, 5}}; + auto static_mod = create_module({static_shape, static_shape}); + std::unordered_map static_arg_shapes{{"x", static_shape}, + {"y", static_shape}}; + + EXPECT(dyn_mod.with_static_shapes(static_arg_shapes).sort() == static_mod.sort()); +} + TEST_CASE(linear_graph_sort) { // diff --git a/test/verify/main.cpp b/test/verify/main.cpp index dde235e3094..0db47d92b24 100644 --- a/test/verify/main.cpp +++ b/test/verify/main.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -142,7 +142,10 @@ int main(int argc, const char* argv[]) "test_bit_cast", "test_bit_cast", "test_bit_cast", - "test_bit_cast"}); + "test_bit_cast", + "test_dynamic_pointwise<4, 16, 24>", + "test_dynamic_pointwise<2, 8, 4>", + "test_dynamic_pointwise<3, 10, 13>"}); rv.disable_test_for("gpu", { // These passes on MI300 but fails on others, same issue as CPU. diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index 6a26a46d810..6197af4c6c8 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -205,9 +205,10 @@ void run_verify::verify(const program_info& pi) const { if(x.second.dynamic()) { - // create static shape using maximum dimensions - migraphx::shape static_shape{x.second.type(), x.second.max_lens()}; - m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first)); + auto static_shape = contains(pi.test_dims, x.first) + ? pi.test_dims.at(x.first) + : migraphx::shape{x.second.type(), x.second.max_lens()}; + m[x.first] = migraphx::generate_argument(static_shape, get_hash(x.first)); } else { diff --git a/test/verify/test_dynamic_pointwise.cpp b/test/verify/test_dynamic_pointwise.cpp new file mode 100644 index 00000000000..7f744f4fc70 --- /dev/null +++ b/test/verify/test_dynamic_pointwise.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include +#include + +template +struct test_dynamic_pointwise : verify_program> +{ + migraphx::program create_program() const + { + migraphx::shape s{migraphx::shape::float_type, {{2, 4}, {8, 16}, {4, 24}}}; + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", s); + auto y = mm->add_parameter("y", s); + auto relu = mm->add_instruction(migraphx::make_op("relu"), x); + auto add = mm->add_instruction(migraphx::make_op("add"), relu, y); + mm->add_return({add}); + return p; + }; + + std::unordered_map get_test_dims() const + { + return {{"x", migraphx::shape{migraphx::shape::float_type, {TestDims...}}}, + {"y", migraphx::shape{migraphx::shape::float_type, {TestDims...}}}}; + } +}; + +template struct test_dynamic_pointwise<4, 16, 24>; // maxes +template struct test_dynamic_pointwise<2, 8, 4>; // mins +template struct test_dynamic_pointwise<3, 10, 13>; // inbetween diff --git a/test/verify/verify_program.hpp b/test/verify/verify_program.hpp index 39d2392bec8..20b42e1cfef 100644 --- a/test/verify/verify_program.hpp +++ b/test/verify/verify_program.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,6 +38,7 @@ struct program_info std::size_t tolerance; std::function get_program; migraphx::compile_options compile_options; + std::unordered_map test_dims; }; void register_program_info(const program_info& pi); @@ -66,6 +67,7 @@ struct register_verify_program_action pi.tolerance = x.get_tolerance(); pi.get_program = [x] { return x.create_program(); }; pi.compile_options = x.get_compile_options(); + pi.test_dims = x.get_test_dims(); register_program_info(pi); } }; @@ -79,6 +81,7 @@ struct verify_program : auto_register_verify_program std::string section() const { return "general"; }; migraphx::compile_options get_compile_options() const { return migraphx::compile_options{}; }; std::size_t get_tolerance() const { return 80; }; + std::unordered_map get_test_dims() const { return {}; }; }; #endif