32 commits
8deac8d
enable dyn shapes for pointwise and reduce fusions
shivadbhavsar Jan 8, 2026
d83f2a0
format
shivadbhavsar Jan 9, 2026
370e386
License update
shivadbhavsar Jan 9, 2026
0061b0a
add dynamic code object op
shivadbhavsar Jan 15, 2026
67119d8
add test for module helper function
shivadbhavsar Jan 16, 2026
46b7adc
fix bug in pw_broadcast_pw and add test case
shivadbhavsar Jan 19, 2026
7689b4c
fix fuse_reduce test
shivadbhavsar Jan 19, 2026
366f7c2
fix dyn shape constructor in tests
shivadbhavsar Jan 19, 2026
a821a23
cleanup add_reduce_module func for tests
shivadbhavsar Jan 19, 2026
fb56f9c
Merge branch 'develop' into dyn_compute_shapes
shivadbhavsar Jan 19, 2026
d860589
Merge branch 'develop' into dyn_compute_shapes
shivadbhavsar Jan 20, 2026
86d09f4
format
shivadbhavsar Jan 20, 2026
242d36f
Merge branch 'develop' into dyn_compute_shapes
shivadbhavsar Jan 22, 2026
caf41fb
Merge remote-tracking branch 'origin/develop' into dyn_code_obj
shivadbhavsar Jan 23, 2026
aa6a4b6
Merge remote-tracking branch 'origin/dyn_compute_shapes' into dyn_cod…
shivadbhavsar Jan 23, 2026
6ee3e5c
add verify test
shivadbhavsar Jan 23, 2026
a5275d7
format
shivadbhavsar Jan 26, 2026
1696bf9
tidy
shivadbhavsar Jan 26, 2026
eed8c42
license
shivadbhavsar Jan 26, 2026
8b14d8f
Merge branch 'develop' into dyn_code_obj
shivadbhavsar Jan 26, 2026
47d191e
Merge remote-tracking branch 'origin/develop' into dyn_code_obj
shivadbhavsar Jan 26, 2026
2e763fa
tidy + format
shivadbhavsar Jan 26, 2026
3a98e06
disable dynamic verify tests for cpu
shivadbhavsar Jan 27, 2026
000a93c
Merge branch 'develop' into dyn_code_obj
shivadbhavsar Jan 27, 2026
241a10c
Merge remote-tracking branch 'origin/develop' into dyn_code_obj
shivadbhavsar Jan 27, 2026
7a6e28b
review updates and cleanup
shivadbhavsar Jan 27, 2026
c8c69a0
tidy
shivadbhavsar Jan 27, 2026
8a523a4
Merge remote-tracking branch 'origin/develop' into dyn_code_obj
shivadbhavsar Jan 28, 2026
0fd36f6
clean up code object op and address review comments
shivadbhavsar Jan 28, 2026
d27ec1e
Merge remote-tracking branch 'origin/develop' into dyn_code_obj
shivadbhavsar Jan 28, 2026
f238dd4
cppcheck fix
shivadbhavsar Jan 29, 2026
121d9d1
tidy
shivadbhavsar Jan 29, 2026
8 changes: 7 additions & 1 deletion src/include/migraphx/module.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -323,6 +323,12 @@ struct MIGRAPHX_EXPORT module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;

std::vector<module_ref> get_sub_modules(bool shallow = false) const;

/* Creates a new module with the same instructions but with different input parameter shapes.
Returns the new module by value without modifying the original.
*/
module with_static_shapes(const std::unordered_map<std::string, shape>& input_shapes);

/* sorts the module in topological order aka reverse post-order (RPO) DFS order:
it takes the last instruction or @return as the root, walks back through the graph, and moves
the inputs of each instruction so that they appear before the instruction itself.
32 changes: 32 additions & 0 deletions src/module.cpp
@@ -25,6 +25,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/module.hpp>
#include <migraphx/bit_signal.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/target.hpp>
@@ -1467,6 +1468,37 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
return vec_modules;
}

module module::with_static_shapes(const std::unordered_map<std::string, shape>& input_shapes)
{
// This routine creates a new module with the same instructions but with different input shapes.
// The sequence of instructions (operators and interconnectivity) is copied, but all input
// parameter shapes are replaced with the corresponding shapes from "input_shapes".

// Ensure input_shapes has an entry for every parameter.
auto param_names = this->get_parameter_names();
assert(param_names.size() == input_shapes.size());

module new_mod;
std::unordered_map<instruction_ref, instruction_ref> ins_map;

// Create parameters with the new shapes in new_mod and record the param mapping in ins_map
for(auto ins : iterator_for(*this))
{
if(ins->name() == "@param")
{
auto pname = any_cast<builtin::param>(ins->get_operator()).parameter;
assert(input_shapes.count(pname) > 0);
ins_map[ins] = new_mod.add_parameter(pname, input_shapes.at(pname));
}
}

// Copy remaining instructions in order
auto ret = new_mod.insert_instructions(new_mod.end(), this, &ins_map);
new_mod.add_return(ret);

return new_mod;
}

module& module::sort()
{
if(this->begin() == this->end())
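For context, a minimal usage sketch of the new helper (illustrative only; the "relu" op and the concrete shapes here are hypothetical, mirroring the with_static_shapes unit test added in test/module_test.cpp below):

// Build a module with dynamic parameter shapes ({min, max} ranges per dimension).
migraphx::shape dyn{migraphx::shape::float_type, {{1, 4}, {4, 8}}};
migraphx::module m;
auto x = m.add_parameter("x", dyn);
m.add_return({m.add_instruction(migraphx::make_op("relu"), x)});

// Specialize: every parameter shape is replaced with a concrete one, and the
// original module m is left unmodified.
migraphx::shape fixed{migraphx::shape::float_type, {2, 5}};
migraphx::module static_m = m.with_static_shapes({{"x", fixed}});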
142 changes: 141 additions & 1 deletion src/targets/gpu/compile_ops.cpp
@@ -79,9 +79,149 @@ struct precompile_op
return shapes.size() - 1;
}
};

MIGRAPHX_REGISTER_OP(precompile_op);

struct dynamic_code_object_op
{
operation pre_op = precompile_op{};
mutable std::shared_ptr<module> cache_mod;
mutable std::shared_ptr<std::vector<shape>> cache_input_shapes;
mutable std::shared_ptr<shape> cache_static_output_shape;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.pre_op, "pre_op"));
}

std::string name() const { return "gpu::dynamic_code_object_op"; }

shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
{
return pre_op.compute_shape(inputs, mods);
}

std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
std::unordered_map<std::string, argument> build_param_map(const std::vector<argument>& args,
const_module_ref mod) const
{
auto pnames = mod->get_parameter_names();
assert(pnames.size() == args.size());
std::unordered_map<std::string, argument> param_map;
std::transform(pnames.begin(),
pnames.end(),
args.begin(),
std::inserter(param_map, param_map.end()),
[](const auto& name, const auto& arg) { return std::make_pair(name, arg); });
return param_map;
}
argument compute(context& ctx,
const shape&,
const std::vector<argument>& args,
const std::vector<module_ref>& module_args,
const std::function<std::vector<argument>(
module_ref&, const std::unordered_map<std::string, argument>&)>& run) const
{
auto static_args = std::vector<argument>{args.begin(), args.end()};
auto output_arg = static_args.back();

if(cache_mod and *cache_input_shapes == to_shapes(args))
{
static_args[static_args.size() - 1] = output_arg.reshape(*cache_static_output_shape);
auto* mod = cache_mod.get();
auto param_map = build_param_map(static_args, mod);
auto results = run(mod, param_map);
if(results.size() > 1)
return results;
return results.front();
}

if(output_arg.get_shape().dynamic())
{
auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args);
static_args[static_args.size() - 1] = output_arg.reshape(out_shape);
}

// Rewrite the submodule with static shapes so it can serve as the IR for compilation
module static_submod;
if(not module_args.empty())
{
auto pnames = module_args.front()->get_parameter_names();
std::unordered_map<std::string, shape> mod_arg_shapes;
std::transform(pnames.begin(),
pnames.end(),
args.begin(),
std::inserter(mod_arg_shapes, mod_arg_shapes.end()),
[&](const auto& name, const auto& arg) {
return std::make_pair(name, arg.get_shape());
});
static_submod = module_args.front()->with_static_shapes(mod_arg_shapes);
static_submod.set_bypass(true);
}

// Create the runtime module that will be compiled and cached. Guard the name lookup:
// instructions lowered without module inputs (e.g. pooling) have empty module_args.
auto name = "runtime_mod:" +
(module_args.empty() ? std::string{} : module_args.front()->name());
auto runtime_mod = module(name);
std::vector<instruction_ref> args_ins;
std::vector<size_t> idx(static_args.size());
std::iota(std::begin(idx), std::end(idx), 0);
std::transform(static_args.begin(),
static_args.end(),
idx.begin(),
std::back_inserter(args_ins),
[&](const auto& arg, const auto& i) {
return runtime_mod.add_parameter(name + ":x" + std::to_string(i),
arg.get_shape());
});
instruction_ref ins;
if(not module_args.empty())
{
ins = runtime_mod.add_instruction(pre_op, args_ins, {&static_submod});
}
else
{
ins = runtime_mod.add_instruction(pre_op, args_ins);
}
runtime_mod.add_return({ins});

// Compile ins and replace with a compiled code object op
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
auto config = get_tuning_config(ctx, ins, preop, false);
value solution = value{};
if(config.has_value())
{
solution = config->solutions.front();
}
auto compiled_op = compile(ctx, ins, preop, solution);
compiled_op.replace(runtime_mod, ins);
run_passes(runtime_mod, {dead_code_elimination{}});

// Finalize the module before execution
std::vector<migraphx::context> contexts = {migraphx::context(ctx)};
runtime_mod.finalize(contexts);

// Update cache
// TODO: This will be updated to store compiled code objects for all encountered shapes
cache_mod = std::make_shared<module>(runtime_mod);
cache_input_shapes = std::make_shared<std::vector<shape>>(to_shapes(args));
cache_static_output_shape = std::make_shared<shape>(static_args.back().get_shape());

// Build the parameter map from the parameters that actually exist in the runtime module
module_ref runtime_mod_ref = &runtime_mod;
auto param_map = build_param_map(static_args, runtime_mod_ref);

auto results = run(runtime_mod_ref, param_map);

if(results.size() > 1)
return results;
return results.front();
}
};
MIGRAPHX_REGISTER_OP(dynamic_code_object_op);
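
The TODO in compute() notes that the single-entry cache will eventually store compiled code objects for all encountered shapes. A minimal sketch of that shape-keyed pattern, assuming a generic compiled type (illustrative only, not existing MIGraphX API):

#include <cstddef>
#include <functional>
#include <map>
#include <vector>

// Illustrative shape-keyed compile cache: one compiled artifact per concrete
// combination of input dimensions, built lazily on first use.
template <class Compiled>
struct shape_keyed_cache
{
    std::map<std::vector<std::vector<std::size_t>>, Compiled> entries;

    Compiled& get(const std::vector<std::vector<std::size_t>>& dims,
                  const std::function<Compiled()>& compile)
    {
        auto it = entries.find(dims);
        if(it != entries.end())
            return it->second; // cache hit: reuse the previously compiled result
        // cache miss: compile once for this shape combination and store it
        return entries.emplace(dims, compile()).first->second;
    }
};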

struct compiled_result
{
compiler_replace replace;
30 changes: 25 additions & 5 deletions src/targets/gpu/lowering.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -182,6 +182,7 @@ struct miopen_apply
else if(has_compiler_for(it->name()))
{
check_shape(s, insert_precompile_op(it));
check_shape(s, insert_dynamic_code_object_op(it));
}
else if(attrs.contains("target"))
{
@@ -240,6 +241,20 @@
ins->module_inputs());
}

instruction_ref insert_dynamic_code_object_op(instruction_ref ins) const
{
assert(ins->get_operator().name() == "gpu::precompile_op");

if(not ins->get_shape().dynamic())
return ins;

return mod->replace_instruction(
ins,
make_op("gpu::dynamic_code_object_op", {{"pre_op", to_value(ins->get_operator())}}),
ins->inputs(),
ins->module_inputs());
}

instruction_ref insert_allocation(instruction_ref ins, const shape& s) const
{
return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
@@ -334,7 +349,8 @@ struct miopen_apply
static bool use_miopen_pooling(instruction_ref ins)
{
if(enabled(MIGRAPHX_DISABLE_MIOPEN_POOLING{}) or
not contains({shape::float_type, shape::half_type}, ins->get_shape().type()))
not contains({shape::float_type, shape::half_type}, ins->get_shape().type()) or
ins->get_shape().dynamic())
return false;
auto&& op = ins->get_operator();
auto op_val = op.to_value();
@@ -355,15 +371,19 @@
{
apply_map.emplace("pooling", [=](instruction_ref ins) {
if(not use_miopen_pooling(ins))
return insert_precompile_op(ins);
{
auto preop = insert_precompile_op(ins);
return insert_dynamic_code_object_op(preop);
}
#if MIGRAPHX_USE_MIOPEN
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
auto&& op = ins->get_operator();
refs.push_back(output);
return mod->replace_instruction(ins, make_op("gpu::pooling", op.to_value()), refs);
#else
return insert_precompile_op(ins);
#else
auto preop = insert_precompile_op(ins);
return insert_dynamic_code_object_op(preop);
#endif
});
}
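Summarizing the dispatch rule the lowering changes above implement (a hypothetical helper named lowered_fusion_op, for illustration only):

// After lowering: instructions with a static output shape keep gpu::precompile_op;
// instructions with a dynamic output shape are additionally wrapped in
// gpu::dynamic_code_object_op, which compiles per concrete shape at run time.
const char* lowered_fusion_op(bool output_is_dynamic)
{
    return output_is_dynamic ? "gpu::dynamic_code_object_op" : "gpu::precompile_op";
}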
68 changes: 68 additions & 0 deletions test/gpu/dynamic_code_object_op.cpp
@@ -0,0 +1,68 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/gpu/lowering.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <test.hpp>
#include <pointwise.hpp>

static void run_lowering(migraphx::program& p, bool offload_copy = false)
{
auto ctx = migraphx::gpu::context{};
migraphx::run_passes(*p.get_main_module(), {migraphx::gpu::lowering{&ctx, offload_copy}});
}

TEST_CASE(dynamic_code_object_op)
{
migraphx::shape s{migraphx::shape::float_type, {{1, 3}, {2, 4}, {6, 6}}};
migraphx::program p1;
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);

auto pw = add_pointwise(p1, "main:pointwise0", {a, b}, single_pointwise("add"));
auto pw_module_inputs = pw->module_inputs();

mm->add_return({pw});

run_lowering(p1);

bool found = false;
for(auto ins : iterator_for(*p1.get_main_module()))
{
if(ins->name() == "gpu::dynamic_code_object_op")
{
found = true;
EXPECT(ins->module_inputs() == pw_module_inputs);
}
}
EXPECT(found);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
25 changes: 24 additions & 1 deletion test/module_test.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -803,6 +803,29 @@ TEST_CASE(add_params)
EXPECT(m1.get_parameter("x1") == map_ins[add]);
}

TEST_CASE(with_static_shapes)
{
auto create_module = [](const std::vector<migraphx::shape>& input_shapes) {
migraphx::module m;
auto x = m.add_parameter("x", input_shapes[0]);
auto y = m.add_parameter("y", input_shapes[1]);
auto add = m.add_instruction(migraphx::make_op("add"), x, y);
auto reduce_mean =
m.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {1}}}), add);
m.add_return({reduce_mean});
return m;
};
auto dyn_shape = migraphx::shape{migraphx::shape::float_type, {{1, 4}, {4, 8}}};
auto dyn_mod = create_module({dyn_shape, dyn_shape});

auto static_shape = migraphx::shape{migraphx::shape::float_type, {2, 5}};
auto static_mod = create_module({static_shape, static_shape});
std::unordered_map<std::string, migraphx::shape> static_arg_shapes{{"x", static_shape},
{"y", static_shape}};

EXPECT(dyn_mod.with_static_shapes(static_arg_shapes).sort() == static_mod.sort());
}

TEST_CASE(linear_graph_sort)
{
//