diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a9d2af7d6d..003da07e73 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -176,7 +176,8 @@ def restrict_fusion_to_newly_created_maps( ) -> bool: return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) - # Using the callback to restrict the fusing + # Now try to fuse the maps together, but restrict the fusion such that at + # least one of the two maps involved is newly created. sdfg.apply_transformations_repeated( [ gtx_transformations.MapFusionSerial( @@ -192,8 +193,20 @@ def restrict_fusion_to_newly_created_maps( validate_all=validate_all, ) - # Now we have to find the maps that were not fused. We rely here on the fact - # that at least one of the map that is involved in fusing still exists. + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is no unique stride dimension. + # - The newly created maps have names that do not reflect GT4Py dimensions, + # thus we cannot use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just swapping the order of the map parameters. + # We further assume here that we only have to process the maps that we have + # newly created. + # NOTE: We can stop relying on this once [PR#1913](https://github.com/GridTools/gt4py/pull/1913) + # has been merged, which is currently blocked by a DaCe PR that has not been + # merged. 
+ maps_to_modify: set[dace_nodes.MapEntry] = set() for nsdfg in sdfg.all_sdfgs_recursive(): for state in nsdfg.states(): @@ -202,17 +215,13 @@ def restrict_fusion_to_newly_created_maps( continue if map_entry in new_maps: maps_to_modify.add(map_entry) - assert 0 < len(maps_to_modify) <= len(new_maps) - # This is a gross hack, but it is needed, for the following reasons: - # - The transients have C order while the non-transients have (most - # likely) FORTRAN order. So there is not an unique stride dimension. - # - The newly created maps have names that does not reflect GT4Py dimensions, - # thus we can not use `gt_set_iteration_order()`. - # For these reasons we do the simplest thing, which is assuming that the maps - # are created in C order and we must make them in FORTRAN order, which means - # just swapping the order of the map parameters. - # TODO(phimuell): Do it properly. + # We did not find any of the newly created Maps. Thus we **hope** that all new + # Maps have been integrated into other Maps that have the correct names. + # But as written above, this is a gross hack! + if len(maps_to_modify) == 0: + return sdfg + for me_to_modify in maps_to_modify: map_to_modify: dace_nodes.Map = me_to_modify.map map_to_modify.params = list(reversed(map_to_modify.params))