Skip to content

Commit dd893ce

Browse files
committed
feat: add StridedMemoryView.from_array_interface
1 parent 9d37149 commit dd893ce

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ cdef class StridedMemoryView:
153153
view_as_cai(obj, stream_ptr, buf)
154154
return buf
155155

156+
@classmethod
157+
def from_array_interface(cls, obj: object) -> StridedMemoryView:
158+
cdef StridedMemoryView buf
159+
with warnings.catch_warnings():
160+
warnings.simplefilter("ignore")
161+
buf = cls()
162+
view_as_array_interface(obj, buf)
163+
return buf
164+
156165
@classmethod
157166
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
158167
if check_has_dlpack(obj):
@@ -597,6 +606,23 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
597606
return buf
598607

599608

609+
cpdef StridedMemoryView view_as_array_interface(obj, view=None):
610+
cdef dict data = obj.__array_interface__
611+
if data["version"] < 3:
612+
raise BufferError("only NumPy Array Interface v3 or above is supported")
613+
if data.get("mask") is not None:
614+
raise BufferError("mask is not supported")
615+
616+
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
617+
buf.exporting_obj = obj
618+
buf.metadata = data
619+
buf.dl_tensor = NULL
620+
buf.ptr, buf.readonly = data["data"]
621+
buf.is_device_accessible = False
622+
buf.device_id = handle_return(driver.cuCtxGetDevice())
623+
return buf
624+
625+
600626
def args_viewable_as_strided_memory(tuple arg_indices):
601627
"""
602628
Decorator to create proxy objects to :obj:`StridedMemoryView` for the

cuda_core/tests/test_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from cuda.core import Device
1919
from cuda.core._layout import _StridedLayout
2020
from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory
21+
from pytest import param
2122

2223

2324
def test_cast_to_3_tuple_success():
@@ -460,3 +461,55 @@ def test_struct_array():
460461
# full dtype information doesn't seem to be preserved due to use of type strings,
461462
# which are lossy, e.g., dtype([("a", "int")]).str == "V8"
462463
assert smv.dtype == np.dtype(f"V{x.itemsize}")
464+
465+
466+
@pytest.mark.parametrize(
467+
("x", "expected_dtype"),
468+
[
469+
# 1D arrays with different dtypes
470+
param(np.array([1, 2, 3], dtype=np.int32), "int32", id="1d-int32"),
471+
param(np.array([1.0, 2.0, 3.0], dtype=np.float64), "float64", id="1d-float64"),
472+
param(np.array([1 + 2j, 3 + 4j], dtype=np.complex128), "complex128", id="1d-complex128"),
473+
param(np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64), "complex64", id="1d-complex64"),
474+
param(np.array([1, 2, 3, 4, 5], dtype=np.uint8), "uint8", id="1d-uint8"),
475+
param(np.array([1, 2], dtype=np.int64), "int64", id="1d-int64"),
476+
param(np.array([100, 200, 300], dtype=np.int16), "int16", id="1d-int16"),
477+
param(np.array([1000, 2000, 3000], dtype=np.uint16), "uint16", id="1d-uint16"),
478+
param(np.array([10000, 20000, 30000], dtype=np.uint64), "uint64", id="1d-uint64"),
479+
# 2D arrays - C-contiguous
480+
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), "int32", id="2d-c-int32"),
481+
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), "float32", id="2d-c-float32"),
482+
# 2D arrays - Fortran-contiguous
483+
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32, order="F"), "int32", id="2d-f-int32"),
484+
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F"), "float64", id="2d-f-float64"),
485+
# 3D arrays
486+
param(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int32), "int32", id="3d-int32"),
487+
param(np.ones((2, 3, 4), dtype=np.float64), "float64", id="3d-float64"),
488+
# Sliced/strided arrays
489+
param(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)[::2], "int32", id="1d-strided-int32"),
490+
param(np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float64)[:, ::2], "float64", id="2d-strided-float64"),
491+
param(np.arange(20, dtype=np.int32).reshape(4, 5)[::2, ::2], "int32", id="2d-strided-2x2-int32"),
492+
# Scalar (0-D array)
493+
param(np.array(42, dtype=np.int32), "int32", id="scalar-int32"),
494+
param(np.array(3.14, dtype=np.float64), "float64", id="scalar-float64"),
495+
# Empty arrays
496+
param(np.array([], dtype=np.int32), "int32", id="empty-1d-int32"),
497+
param(np.empty((0, 3), dtype=np.float64), "float64", id="empty-2d-float64"),
498+
# Single element
499+
param(np.array([1], dtype=np.int32), "int32", id="single-element"),
500+
# Structured dtype
501+
param(np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]), "V12", id="structured-dtype"),
502+
],
503+
)
504+
def test_from_array_interface(x, expected_dtype):
505+
smv = StridedMemoryView.from_array_interface(x)
506+
assert smv.size == x.size
507+
assert smv.dtype == np.dtype(expected_dtype)
508+
assert smv.shape == x.shape
509+
assert smv.ptr == x.ctypes.data
510+
assert smv.is_device_accessible is False
511+
assert smv.exporting_obj is x
512+
assert smv.readonly is not x.flags.writeable
513+
# Check strides
514+
strides_in_counts = convert_strides_to_counts(x.strides, x.dtype.itemsize)
515+
assert (x.flags.c_contiguous and smv.strides is None) or smv.strides == strides_in_counts

0 commit comments

Comments
 (0)