|
18 | 18 | from cuda.core import Device |
19 | 19 | from cuda.core._layout import _StridedLayout |
20 | 20 | from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory |
| 21 | +from pytest import param |
21 | 22 |
|
22 | 23 |
|
23 | 24 | def test_cast_to_3_tuple_success(): |
@@ -460,3 +461,55 @@ def test_struct_array(): |
460 | 461 | # full dtype information doesn't seem to be preserved due to use of type strings, |
461 | 462 | # which are lossy, e.g., dtype([("a", "int")]).str == "V8" |
462 | 463 | assert smv.dtype == np.dtype(f"V{x.itemsize}") |
| 464 | + |
| 465 | + |
| 466 | +@pytest.mark.parametrize( |
| 467 | + ("x", "expected_dtype"), |
| 468 | + [ |
| 469 | + # 1D arrays with different dtypes |
| 470 | + param(np.array([1, 2, 3], dtype=np.int32), "int32", id="1d-int32"), |
| 471 | + param(np.array([1.0, 2.0, 3.0], dtype=np.float64), "float64", id="1d-float64"), |
| 472 | + param(np.array([1 + 2j, 3 + 4j], dtype=np.complex128), "complex128", id="1d-complex128"), |
| 473 | + param(np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64), "complex64", id="1d-complex64"), |
| 474 | + param(np.array([1, 2, 3, 4, 5], dtype=np.uint8), "uint8", id="1d-uint8"), |
| 475 | + param(np.array([1, 2], dtype=np.int64), "int64", id="1d-int64"), |
| 476 | + param(np.array([100, 200, 300], dtype=np.int16), "int16", id="1d-int16"), |
| 477 | + param(np.array([1000, 2000, 3000], dtype=np.uint16), "uint16", id="1d-uint16"), |
| 478 | + param(np.array([10000, 20000, 30000], dtype=np.uint64), "uint64", id="1d-uint64"), |
| 479 | + # 2D arrays - C-contiguous |
| 480 | + param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), "int32", id="2d-c-int32"), |
| 481 | + param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), "float32", id="2d-c-float32"), |
| 482 | + # 2D arrays - Fortran-contiguous |
| 483 | + param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32, order="F"), "int32", id="2d-f-int32"), |
| 484 | + param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F"), "float64", id="2d-f-float64"), |
| 485 | + # 3D arrays |
| 486 | + param(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int32), "int32", id="3d-int32"), |
| 487 | + param(np.ones((2, 3, 4), dtype=np.float64), "float64", id="3d-float64"), |
| 488 | + # Sliced/strided arrays |
| 489 | + param(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)[::2], "int32", id="1d-strided-int32"), |
| 490 | + param(np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float64)[:, ::2], "float64", id="2d-strided-float64"), |
| 491 | + param(np.arange(20, dtype=np.int32).reshape(4, 5)[::2, ::2], "int32", id="2d-strided-2x2-int32"), |
| 492 | + # Scalar (0-D array) |
| 493 | + param(np.array(42, dtype=np.int32), "int32", id="scalar-int32"), |
| 494 | + param(np.array(3.14, dtype=np.float64), "float64", id="scalar-float64"), |
| 495 | + # Empty arrays |
| 496 | + param(np.array([], dtype=np.int32), "int32", id="empty-1d-int32"), |
| 497 | + param(np.empty((0, 3), dtype=np.float64), "float64", id="empty-2d-float64"), |
| 498 | + # Single element |
| 499 | + param(np.array([1], dtype=np.int32), "int32", id="single-element"), |
| 500 | + # Structured dtype |
| 501 | + param(np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]), "V12", id="structured-dtype"), |
| 502 | + ], |
| 503 | +) |
| 504 | +def test_from_array_interface(x, expected_dtype): |
| 505 | + smv = StridedMemoryView.from_array_interface(x) |
| 506 | + assert smv.size == x.size |
| 507 | + assert smv.dtype == np.dtype(expected_dtype) |
| 508 | + assert smv.shape == x.shape |
| 509 | + assert smv.ptr == x.ctypes.data |
| 510 | + assert smv.is_device_accessible is False |
| 511 | + assert smv.exporting_obj is x |
| 512 | + assert smv.readonly is not x.flags.writeable |
| 513 | + # Check strides |
| 514 | + strides_in_counts = convert_strides_to_counts(x.strides, x.dtype.itemsize) |
| 515 | + assert (x.flags.c_contiguous and smv.strides is None) or smv.strides == strides_in_counts |
0 commit comments