From 51b941f512ce1457cc9282f62c4ff90e19220370 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 28 Apr 2025 13:16:21 +0200 Subject: [PATCH] Fixes an issue in the expander for unsupported Memlets. We have to expand some Memlets into copy Maps, because there is no suitable call to `cudaMemcpy*` for them, and we then have to set the order of the Map parameters right. The latter is a gross hack; it should be solved once [PR#1913](https://github.com/GridTools/gt4py/pull/1913) has been merged, but for that DaCe [PR#1976](https://github.com/spcl/dace/pull/1976) has to be merged first. The error was the assumption that at least one of these newly created Maps survives. However, this does not seem to be the case, depending on the order in which MapFusion is applied. Thus the assert was removed. We now assume that, if we do not find a newly created MapEntry, it was integrated into a Map that was already present; we further assume that the Map parameters of that Map were used and that they are correct. --- .../runners/dace/transformations/gpu_utils.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a9d2af7d6d..003da07e73 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -176,7 +176,8 @@ def restrict_fusion_to_newly_created_maps( ) -> bool: return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) - # Using the callback to restrict the fusing + # Now try to fuse the maps together, but restrict the fusion such that at + # least one of the two maps involved is new. 
sdfg.apply_transformations_repeated( [ gtx_transformations.MapFusionSerial( @@ -192,8 +193,20 @@ def restrict_fusion_to_newly_created_maps( validate_all=validate_all, ) - # Now we have to find the maps that were not fused. We rely here on the fact - # that at least one of the map that is involved in fusing still exists. + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is no unique stride dimension. + # - The newly created maps have names that do not reflect GT4Py dimensions, + # thus we can not use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just reversing the order of the map parameters. + # We further assume here that we only have to process the maps that we have + # newly created. + # NOTE: We can stop relying on this once [PR#1913](https://github.com/GridTools/gt4py/pull/1913) + # has been merged, which is currently blocked by a DaCe PR that has not been + # merged. + maps_to_modify: set[dace_nodes.MapEntry] = set() for nsdfg in sdfg.all_sdfgs_recursive(): for state in nsdfg.states(): @@ -202,17 +215,13 @@ def restrict_fusion_to_newly_created_maps( continue if map_entry in new_maps: maps_to_modify.add(map_entry) - assert 0 < len(maps_to_modify) <= len(new_maps) - # This is a gross hack, but it is needed, for the following reasons: - # - The transients have C order while the non-transients have (most - # likely) FORTRAN order. So there is not an unique stride dimension. - # - The newly created maps have names that does not reflect GT4Py dimensions, - # thus we can not use `gt_set_iteration_order()`. 
- # For these reasons we do the simplest thing, which is assuming that the maps - # are created in C order and we must make them in FORTRAN order, which means - # just swapping the order of the map parameters. - # TODO(phimuell): Do it properly. + # We did not find any of the newly created Maps. Thus we **hope** that all new + # Maps have been integrated into other Maps that have the correct names. + # But as written above, this is a gross hack! + if len(maps_to_modify) == 0: + return sdfg + for me_to_modify in maps_to_modify: map_to_modify: dace_nodes.Map = me_to_modify.map map_to_modify.params = list(reversed(map_to_modify.params))