diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 003da07e73..d121e46012 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -258,7 +258,9 @@ def _gt_expand_non_standard_memlets_sdfg( ) -> set[dace_nodes.MapEntry]: """Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG.""" new_maps: set[dace_nodes.MapEntry] = set() - # The implementation is based on DaCe's code generator. + # The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py` + # in the function `preprocess()` + # NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976 for state in sdfg.states(): for e in state.edges(): # We are only interested in edges that connects two access nodes of GPU memory. @@ -279,16 +281,9 @@ def _gt_expand_non_standard_memlets_sdfg( if dims == 1: continue elif dims == 2: - if src_strides[-1] != 1 or dst_strides[-1] != 1: - try: - is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] - is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] - except (TypeError, ValueError): - is_src_cont = False - is_dst_cont = False - if is_src_cont and is_dst_cont: - continue - else: + is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1 + is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1 + if is_c_order or is_fortran_order: continue elif dims > 2: if not (src_strides[-1] != 1 or dst_strides[-1] != 1):