From 18f2e26f817f0513fb2ad10c591d6909348cfe7c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 17 Mar 2025 11:21:29 +0100 Subject: [PATCH] Adjusted the handling of non standard GPU Memlets. Once [PR#1976](https://github.com/spcl/dace/pull/1976) is merged in DaCe the code generator is able to handle more Memlets directly as Cuda `memcpy()` calls. This PR modifies the GPU transformation of GT4Py in such a way that these Memlets are no longer transformed into Maps. However, it can only be merged if the DaCe dependency was bumped to a version that includes PR#1976! --- .../runners/dace/transformations/gpu_utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index e1f105f0ef..8f487e30c6 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -248,7 +248,9 @@ def _gt_expand_non_standard_memlets_sdfg( ) -> set[dace_nodes.MapEntry]: """Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG.""" new_maps: set[dace_nodes.MapEntry] = set() - # The implementation is based on DaCe's code generator. + # The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py` + # in the function `preprocess()` + # NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976 for state in sdfg.states(): for e in state.edges(): # We are only interested in edges that connects two access nodes of GPU memory. @@ -269,16 +271,9 @@ def _gt_expand_non_standard_memlets_sdfg( if dims == 1: continue elif dims == 2: - if src_strides[-1] != 1 or dst_strides[-1] != 1: - try: - is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] - is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] - except (TypeError, ValueError): - is_src_cont = False - is_dst_cont = False - if is_src_cont and is_dst_cont: - continue - else: + is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1 + is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1 + if is_c_order or is_fortran_order: continue elif dims > 2: if not (src_strides[-1] != 1 or dst_strides[-1] != 1):