Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def restrict_fusion_to_newly_created_maps(
) -> bool:
return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2])

# Using the callback to restrict the fusing
# Now try to fuse the maps together, but restrict them that at least one map
# needs to be new.
sdfg.apply_transformations_repeated(
[
gtx_transformations.MapFusionSerial(
Expand All @@ -192,8 +193,20 @@ def restrict_fusion_to_newly_created_maps(
validate_all=validate_all,
)

# Now we have to find the maps that were not fused. We rely here on the fact
# that at least one of the map that is involved in fusing still exists.
# This is a gross hack, but it is needed, for the following reasons:
# - The transients have C order while the non-transients have (most
# likely) FORTRAN order. So there is not an unique stride dimension.
# - The newly created maps have names that does not reflect GT4Py dimensions,
# thus we can not use `gt_set_iteration_order()`.
# For these reasons we do the simplest thing, which is assuming that the maps
# are created in C order and we must make them in FORTRAN order, which means
# just swapping the order of the map parameters.
# We further assume here, that we only have to process the maps that we have
# newly created.
# NOTE: We can stop relying on this once [PR#1913](https://github.com/GridTools/gt4py/pull/1913)
# Has been merged, which is currently blocked by a DaCe PR that has not been
# merged.

maps_to_modify: set[dace_nodes.MapEntry] = set()
for nsdfg in sdfg.all_sdfgs_recursive():
for state in nsdfg.states():
Expand All @@ -202,17 +215,13 @@ def restrict_fusion_to_newly_created_maps(
continue
if map_entry in new_maps:
maps_to_modify.add(map_entry)
assert 0 < len(maps_to_modify) <= len(new_maps)

# This is a gross hack, but it is needed, for the following reasons:
# - The transients have C order while the non-transients have (most
# likely) FORTRAN order. So there is not an unique stride dimension.
# - The newly created maps have names that does not reflect GT4Py dimensions,
# thus we can not use `gt_set_iteration_order()`.
# For these reasons we do the simplest thing, which is assuming that the maps
# are created in C order and we must make them in FORTRAN order, which means
# just swapping the order of the map parameters.
# TODO(phimuell): Do it properly.
# We did not found any of the newly created Map. Thus we **hope** that all new
# Maps have been integrated into other Maps, that have the correct names.
# But as written above, this is a gross hack!
if len(maps_to_modify) == 0:
return sdfg

for me_to_modify in maps_to_modify:
map_to_modify: dace_nodes.Map = me_to_modify.map
map_to_modify.params = list(reversed(map_to_modify.params))
Expand Down