Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ def _gt_expand_non_standard_memlets_sdfg(
) -> set[dace_nodes.MapEntry]:
"""Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG."""
new_maps: set[dace_nodes.MapEntry] = set()
# The implementation is based on DaCe's code generator.
# The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py`
# in the function `preprocess()`
# NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976
for state in sdfg.states():
for e in state.edges():
# We are only interested in edges that connect two access nodes of GPU memory.
Expand All @@ -279,16 +281,9 @@ def _gt_expand_non_standard_memlets_sdfg(
if dims == 1:
continue
elif dims == 2:
if src_strides[-1] != 1 or dst_strides[-1] != 1:
try:
is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
except (TypeError, ValueError):
is_src_cont = False
is_dst_cont = False
if is_src_cont and is_dst_cont:
continue
else:
is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1
is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
if is_c_order or is_fortran_order:
continue
elif dims > 2:
if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
Expand Down