From 7e3574b89bd546b838b52dd6a73aff864691171b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:38:03 -0500 Subject: [PATCH 1/8] feat: relax the power of two check in StridedLayout (#1427) (#1471) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- cuda_core/cuda/core/_layout.pxd | 14 ++++---------- cuda_core/cuda/core/_layout.pyx | 29 ++++++++++++++--------------- cuda_core/tests/test_utils.py | 13 +++++++++++++ 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/cuda_core/cuda/core/_layout.pxd b/cuda_core/cuda/core/_layout.pxd index 918a104f2f..8e2ead82a6 100644 --- a/cuda_core/cuda/core/_layout.pxd +++ b/cuda_core/cuda/core/_layout.pxd @@ -111,7 +111,8 @@ cdef class _StridedLayout: # ============================== cdef inline int _init(_StridedLayout self, BaseLayout& base, int itemsize, bint divide_strides=False) except -1 nogil: - _validate_itemsize(itemsize) + if itemsize <= 0: + raise ValueError("itemsize must be positive") if base.strides != NULL and divide_strides: _divide_strides(base, itemsize) @@ -123,7 +124,8 @@ cdef class _StridedLayout: return 0 cdef inline stride_t _init_dense(_StridedLayout self, BaseLayout& base, int itemsize, OrderFlag order_flag, axis_vec_t* stride_order=NULL) except -1 nogil: - _validate_itemsize(itemsize) + if itemsize <= 0: + raise ValueError("itemsize must be positive") cdef stride_t volume if order_flag == ORDER_C: @@ -643,14 +645,6 @@ cdef inline bint _normalize_axis(integer_t& axis, integer_t extent) except -1 no return True -cdef inline int _validate_itemsize(int itemsize) except -1 nogil: - if itemsize <= 0: - raise ValueError("itemsize must be positive") - if itemsize & (itemsize - 1): - raise ValueError("itemsize must be a power of two") - return 0 - - cdef inline bint _is_unique(BaseLayout& base, axis_vec_t& stride_order) except -1 nogil: if base.strides == NULL: return True diff --git a/cuda_core/cuda/core/_layout.pyx b/cuda_core/cuda/core/_layout.pyx index b1ff975dc9..3c9392430b 100644 --- a/cuda_core/cuda/core/_layout.pyx +++ b/cuda_core/cuda/core/_layout.pyx @@ -29,7 +29,7 @@ cdef class _StridedLayout: Otherwise, the strides are assumed to be implicitly C-contiguous and the resulting layout's :attr:`strides` will be None. itemsize : int - The number of bytes per single element (dtype size). Must be a power of two. + The number of bytes per single element (dtype size). divide_strides : bool, optional If True, the provided :attr:`strides` will be divided by the :attr:`itemsize`. @@ -40,7 +40,7 @@ cdef class _StridedLayout: Attributes ---------- itemsize : int - The number of bytes per single element (dtype size). Must be a power of two. + The number of bytes per single element (dtype size). slice_offset : int The offset (as a number of elements, not bytes) of the element at index ``(0,) * ndim``. See also :attr:`slice_offset_in_bytes`. @@ -636,7 +636,6 @@ cdef class _StridedLayout: In either case, the ``volume * itemsize`` of the layout remains the same. The conversion is subject to the following constraints: - * The old and new itemsizes must be powers of two. * The extent at ``axis`` must be a positive integer. * The stride at ``axis`` must be 1. @@ -1214,10 +1213,10 @@ cdef inline int64_t gcd(int64_t a, int64_t b) except? 
-1 nogil: cdef inline int pack_extents(BaseLayout& out_layout, stride_t& out_slice_offset, BaseLayout& in_layout, stride_t slice_offset, int itemsize, int new_itemsize, intptr_t data_ptr, bint keep_dim, int axis) except -1 nogil: cdef int ndim = in_layout.ndim - if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1): - raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.") - if itemsize <= 0 or itemsize & (itemsize - 1): - raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if new_itemsize <= 0: + raise ValueError(f"new itemsize must be greater than zero, got {new_itemsize}.") + if itemsize <= 0: + raise ValueError(f"itemsize must be greater than zero, got {itemsize}.") if new_itemsize <= itemsize: if new_itemsize == itemsize: return 1 @@ -1270,10 +1269,10 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in cdef int ndim = in_layout.ndim if not _normalize_axis(axis, ndim): raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") - if new_itemsize <= 0 or new_itemsize & (new_itemsize - 1): - raise ValueError(f"new itemsize must be a power of two, got {new_itemsize}.") - if itemsize <= 0 or itemsize & (itemsize - 1): - raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if new_itemsize <= 0: + raise ValueError(f"new itemsize must be greater than zero, got {new_itemsize}.") + if itemsize <= 0: + raise ValueError(f"itemsize must be greater than zero, got {itemsize}.") if new_itemsize >= itemsize: if new_itemsize == itemsize: return 1 @@ -1301,10 +1300,10 @@ cdef inline int unpack_extents(BaseLayout &out_layout, BaseLayout &in_layout, in cdef inline int max_compatible_itemsize(BaseLayout& layout, stride_t slice_offset, int itemsize, int max_itemsize, intptr_t data_ptr, int axis) except? 
-1 nogil: cdef int ndim = layout.ndim - if max_itemsize <= 0 or max_itemsize & (max_itemsize - 1): - raise ValueError(f"max_itemsize must be a power of two, got {max_itemsize}.") - if itemsize <= 0 or itemsize & (itemsize - 1): - raise ValueError(f"itemsize must be a power of two, got {itemsize}.") + if max_itemsize <= 0: + raise ValueError(f"max_itemsize must be greater than zero, got {max_itemsize}.") + if itemsize <= 0: + raise ValueError(f"itemsize must be greater than zero, got {itemsize}.") if not _normalize_axis(axis, ndim): raise ValueError(f"Invalid axis: {axis} out of range for {ndim}D tensor") if max_itemsize < itemsize: diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index c778a9e493..06ee3520e2 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -433,3 +433,16 @@ def test_view_zero_size_array(api, shape, dtype): assert smv.size == 0 assert smv.shape == shape assert smv.dtype == np.dtype(dtype) + + +def test_from_buffer_with_non_power_of_two_itemsize(): + dev = Device() + dev.set_current() + dtype = np.dtype([("a", "int32"), ("b", "int8")]) + shape = (1,) + layout = _StridedLayout(shape=shape, strides=None, itemsize=dtype.itemsize) + required_size = layout.required_size_in_bytes() + assert required_size == math.prod(shape) * dtype.itemsize + buffer = dev.memory_resource.allocate(required_size) + view = StridedMemoryView.from_buffer(buffer, shape=shape, strides=layout.strides, dtype=dtype, is_readonly=True) + assert view.dtype == dtype From b4e3d1ab252da85f308f3167bf7828a66ed9ba3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:39:16 -0500 Subject: [PATCH 2/8] feat: improve structured dtype array support in `StridedMemoryView` (#1425) (#1472) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- cuda_core/cuda/core/_memoryview.pyx | 32 +++++++++-------------------- cuda_core/tests/test_utils.py | 17 +++++++++++++++ 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index c12cbbaa8a..1dbebe7789 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -365,8 +365,7 @@ cdef class StridedMemoryView: if self.dl_tensor != NULL: self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype) elif self.metadata is not None: - # TODO: this only works for built-in numeric types - self._dtype = _typestr2dtype[self.metadata["typestr"]] + self._dtype = _typestr2dtype(self.metadata["typestr"]) return self._dtype @@ -486,25 +485,14 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None): return buf -_builtin_numeric_dtypes = [ - numpy.dtype("uint8"), - numpy.dtype("uint16"), - numpy.dtype("uint32"), - numpy.dtype("uint64"), - numpy.dtype("int8"), - numpy.dtype("int16"), - numpy.dtype("int32"), - numpy.dtype("int64"), - numpy.dtype("float16"), - numpy.dtype("float32"), - numpy.dtype("float64"), - numpy.dtype("complex64"), - numpy.dtype("complex128"), - numpy.dtype("bool"), -] -# Doing it once to avoid repeated overhead -_typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes} -_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes} +@functools.lru_cache +def _typestr2dtype(str typestr): + return numpy.dtype(typestr) + + +@functools.lru_cache +def _typestr2itemsize(str typestr): + return _typestr2dtype(typestr).itemsize cdef object 
dtype_dlpack_to_numpy(DLDataType* dtype): @@ -664,7 +652,7 @@ cdef _StridedLayout layout_from_cai(object metadata): cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout) cdef object shape = metadata["shape"] cdef object strides = metadata.get("strides") - cdef int itemsize = _typestr2itemsize[metadata["typestr"]] + cdef int itemsize = _typestr2itemsize(metadata["typestr"]) layout.init_from_tuple(shape, strides, itemsize, True) return layout diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index 06ee3520e2..eb883cd3f3 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -446,3 +446,20 @@ def test_from_buffer_with_non_power_of_two_itemsize(): buffer = dev.memory_resource.allocate(required_size) view = StridedMemoryView.from_buffer(buffer, shape=shape, strides=layout.strides, dtype=dtype, is_readonly=True) assert view.dtype == dtype + + +def test_struct_array(): + cp = pytest.importorskip("cupy") + + x = np.array([(1.0, 2), (2.0, 3)], dtype=[("array1", np.float64), ("array2", np.int64)]) + + y = cp.empty(2, dtype=x.dtype) + y.set(x) + + smv = StridedMemoryView.from_cuda_array_interface(y, stream_ptr=0) + assert smv.size * smv.dtype.itemsize == x.nbytes + assert smv.size == x.size + assert smv.shape == x.shape + # full dtype information doesn't seem to be preserved due to use of type strings, + # which are lossy, e.g., dtype([("a", "int")]).str == "V8" + assert smv.dtype == np.dtype(f"V{x.itemsize}") From a442961d4311b5d74ae18cc419003430f62bc475 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:40:26 -0500 Subject: [PATCH 3/8] feat: allow constructing SMV from numpy arrays (#1428) (#1473) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- cuda_core/cuda/core/_memoryview.pyx | 32 ++++++++++++++- cuda_core/tests/conftest.py | 2 +- cuda_core/tests/test_utils.py | 64 +++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index 1dbebe7789..ef85d132aa 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -139,7 +139,9 @@ cdef class StridedMemoryView: def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView: cdef StridedMemoryView buf with warnings.catch_warnings(): - warnings.simplefilter("ignore") + # ignore the warning triggered by calling the constructor + # inside the library we're allowed to do this + warnings.simplefilter("ignore", DeprecationWarning) buf = cls() view_as_dlpack(obj, stream_ptr, buf) return buf @@ -148,11 +150,20 @@ cdef class StridedMemoryView: def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView: cdef StridedMemoryView buf with warnings.catch_warnings(): - warnings.simplefilter("ignore") + warnings.simplefilter("ignore", DeprecationWarning) buf = cls() view_as_cai(obj, stream_ptr, buf) return buf + @classmethod + def from_array_interface(cls, obj: object) -> StridedMemoryView: + cdef StridedMemoryView buf + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + buf = cls() + view_as_array_interface(obj, buf) + return buf + @classmethod def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView: if check_has_dlpack(obj): @@ -597,6 +608,23 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None): return buf 
+cpdef StridedMemoryView view_as_array_interface(obj, view=None): + cdef dict data = obj.__array_interface__ + if data["version"] < 3: + raise BufferError("only NumPy Array Interface v3 or above is supported") + if data.get("mask") is not None: + raise BufferError("mask is not supported") + + cdef StridedMemoryView buf = StridedMemoryView() if view is None else view + buf.exporting_obj = obj + buf.metadata = data + buf.dl_tensor = NULL + buf.ptr, buf.readonly = data["data"] + buf.is_device_accessible = False + buf.device_id = handle_return(driver.cuCtxGetDevice()) + return buf + + def args_viewable_as_strided_memory(tuple arg_indices): """ Decorator to create proxy objects to :obj:`StridedMemoryView` for the diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index 0dac8f7def..340e632719 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -72,7 +72,7 @@ def init_cuda(): driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC) ) - yield + yield device _ = _device_unset_current() diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index eb883cd3f3..b185489df2 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -18,6 +18,7 @@ from cuda.core import Device from cuda.core._layout import _StridedLayout from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory +from pytest import param def test_cast_to_3_tuple_success(): @@ -463,3 +464,66 @@ def test_struct_array(): # full dtype information doesn't seem to be preserved due to use of type strings, # which are lossy, e.g., dtype([("a", "int")]).str == "V8" assert smv.dtype == np.dtype(f"V{x.itemsize}") + + +@pytest.mark.parametrize( + ("x", "expected_dtype"), + [ + # 1D arrays with different dtypes + param(np.array([1, 2, 3], dtype=np.int32), "int32", id="1d-int32"), + param(np.array([1.0, 2.0, 3.0], dtype=np.float64), "float64", id="1d-float64"), + param(np.array([1 + 2j, 3 + 4j], dtype=np.complex128), "complex128", id="1d-complex128"), + param(np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64), "complex64", id="1d-complex64"), + param(np.array([1, 2, 3, 4, 5], dtype=np.uint8), "uint8", id="1d-uint8"), + param(np.array([1, 2], dtype=np.int64), "int64", id="1d-int64"), + param(np.array([100, 200, 300], dtype=np.int16), "int16", id="1d-int16"), + param(np.array([1000, 2000, 3000], dtype=np.uint16), "uint16", id="1d-uint16"), + param(np.array([10000, 20000, 30000], dtype=np.uint64), "uint64", id="1d-uint64"), + # 2D arrays - C-contiguous + param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), "int32", id="2d-c-int32"), + param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), "float32", id="2d-c-float32"), + # 2D arrays - Fortran-contiguous + param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32, order="F"), "int32", id="2d-f-int32"), + param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F"), "float64", id="2d-f-float64"), + # 3D arrays + param(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int32), "int32", id="3d-int32"), + param(np.ones((2, 3, 4), dtype=np.float64), "float64", id="3d-float64"), + # Sliced/strided arrays + param(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)[::2], "int32", id="1d-strided-int32"), + param(np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float64)[:, ::2], "float64", id="2d-strided-float64"), + param(np.arange(20, dtype=np.int32).reshape(4, 5)[::2, ::2], "int32", id="2d-strided-2x2-int32"), + # Scalar (0-D array) + 
param(np.array(42, dtype=np.int32), "int32", id="scalar-int32"), + param(np.array(3.14, dtype=np.float64), "float64", id="scalar-float64"), + # Empty arrays + param(np.array([], dtype=np.int32), "int32", id="empty-1d-int32"), + param(np.empty((0, 3), dtype=np.float64), "float64", id="empty-2d-float64"), + # Single element + param(np.array([1], dtype=np.int32), "int32", id="single-element"), + # Structured dtype + param(np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]), "V12", id="structured-dtype"), + ], +) +def test_from_array_interface(x, init_cuda, expected_dtype): + smv = StridedMemoryView.from_array_interface(x) + assert smv.size == x.size + assert smv.dtype == np.dtype(expected_dtype) + assert smv.shape == x.shape + assert smv.ptr == x.ctypes.data + assert smv.device_id == init_cuda.device_id + assert smv.is_device_accessible is False + assert smv.exporting_obj is x + assert smv.readonly is not x.flags.writeable + # Check strides + strides_in_counts = convert_strides_to_counts(x.strides, x.dtype.itemsize) + assert (x.flags.c_contiguous and smv.strides is None) or smv.strides == strides_in_counts + + +def test_from_array_interface_unsupported_strides(init_cuda): + # Create an array with strides that aren't a multiple of itemsize + x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]) + b = x["b"] + smv = StridedMemoryView.from_array_interface(b) + with pytest.raises(ValueError, match="strides must be divisible by itemsize"): + # TODO: ideally this would raise on construction + smv.strides # noqa: B018 From 8545d74e042d3468bc7cb0903b2bb5c479039deb Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 14:42:15 -0500 Subject: [PATCH 4/8] perf: remove warnings calls in smv constructor methods (#1431) (#1474) Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- cuda_core/cuda/core/_memoryview.pyx | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx index ef85d132aa..eb2ccc13f7 100644 --- a/cuda_core/cuda/core/_memoryview.pyx +++ b/cuda_core/cuda/core/_memoryview.pyx @@ -137,30 +137,19 @@ cdef class StridedMemoryView: @classmethod def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView: - cdef StridedMemoryView buf - with warnings.catch_warnings(): - # ignore the warning triggered by calling the constructor - # inside the library we're allowed to do this - warnings.simplefilter("ignore", DeprecationWarning) - buf = cls() + cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) view_as_dlpack(obj, stream_ptr, buf) return buf @classmethod def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView: - cdef StridedMemoryView buf - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - buf = cls() + cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) view_as_cai(obj, stream_ptr, buf) return buf @classmethod def from_array_interface(cls, obj: object) -> StridedMemoryView: - cdef StridedMemoryView buf - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - buf = cls() + cdef StridedMemoryView buf = StridedMemoryView.__new__(cls) view_as_array_interface(obj, buf) return buf From 81c83f2a67579403d9a36b8345d9716d8f83e4f9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 13 Jan 
2026 15:12:57 -0500 Subject: [PATCH 5/8] ci: enable running ci against release branches (#1475) --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8733498d91..844cac6f50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,6 +15,7 @@ on: branches: - "pull-request/[0-9]+" - "main" + - "release/*" jobs: ci-vars: From e0a11efd6f16554b902c267516b820bb8ddde765 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 23:27:14 +0000 Subject: [PATCH 6/8] Fix #1417: Fix test for Numpy 2.4.0 (#1418) (#1476) Co-authored-by: Michael Droettboom Fix #1417: Fix test for Numpy 2.4.0 --- cuda_core/tests/test_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index b185489df2..a3c62a6aee 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -96,10 +96,7 @@ def _check_view(self, view, in_arr): assert view.shape == in_arr.shape assert view.size == in_arr.size strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) - if in_arr.flags.c_contiguous: - assert view.strides is None - else: - assert view.strides == strides_in_counts + assert (in_arr.flags.c_contiguous and view.strides is None) or view.strides == strides_in_counts assert view.dtype == in_arr.dtype assert view.device_id == -1 assert view.is_device_accessible is False From 30a5cc01031da7cc5f0188984955313b71fae324 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 14 Jan 2026 11:28:13 -0500 Subject: [PATCH 7/8] chore: bump version to 0.5.1 --- cuda_core/cuda/core/_version.py | 2 +- cuda_core/docs/nv-versions.json | 4 ++++ cuda_core/pixi.lock | 18 +++--------------- cuda_core/pixi.toml | 2 +- 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/cuda_core/cuda/core/_version.py b/cuda_core/cuda/core/_version.py index 44683fadb3..6b9aee5f90 100644 --- a/cuda_core/cuda/core/_version.py +++ b/cuda_core/cuda/core/_version.py @@ -2,4 +2,4 @@ # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.5.0" +__version__ = "0.5.1" diff --git a/cuda_core/docs/nv-versions.json b/cuda_core/docs/nv-versions.json index d0294bc722..8f5578e3fe 100644 --- a/cuda_core/docs/nv-versions.json +++ b/cuda_core/docs/nv-versions.json @@ -3,6 +3,10 @@ "version": "latest", "url": "https://nvidia.github.io/cuda-python/cuda-core/latest/" }, + { + "version": "0.5.1", + "url": "https://nvidia.github.io/cuda-python/cuda-core/0.5.1/" + }, { "version": "0.5.0", "url": "https://nvidia.github.io/cuda-python/cuda-core/0.5.0/" diff --git a/cuda_core/pixi.lock b/cuda_core/pixi.lock index 16a0d2460f..005fab59dd 100644 --- a/cuda_core/pixi.lock +++ b/cuda_core/pixi.lock @@ -1051,7 +1051,7 @@ packages: timestamp: 1764878612030 - conda: . name: cuda-core - version: 0.4.2 + version: 0.5.1 build: py314h59f3c06_0 subdir: linux-64 variants: @@ -1067,13 +1067,9 @@ packages: - python_abi 3.14.* *_cp314 - cuda-cudart >=13.1.80,<14.0a0 license: Apache-2.0 - input: - hash: cccb645b22f775570680f1a9a62e415a09774e46645523bbd147226681155628 - globs: - - pyproject.toml - conda: . 
name: cuda-core - version: 0.4.2 + version: 0.5.1 build: py314h625260f_0 subdir: win-64 variants: @@ -1087,13 +1083,9 @@ packages: - vc14_runtime >=14.16.27033 - python_abi 3.14.* *_cp314 license: Apache-2.0 - input: - hash: cccb645b22f775570680f1a9a62e415a09774e46645523bbd147226681155628 - globs: - - pyproject.toml - conda: . name: cuda-core - version: 0.4.2 + version: 0.5.1 build: py314ha479ada_0 subdir: linux-aarch64 variants: @@ -1109,10 +1101,6 @@ packages: - python_abi 3.14.* *_cp314 - cuda-cudart >=13.1.80,<14.0a0 license: Apache-2.0 - input: - hash: cccb645b22f775570680f1a9a62e415a09774e46645523bbd147226681155628 - globs: - - pyproject.toml - conda: https://conda.anaconda.org/conda-forge/noarch/cuda-crt-dev_linux-64-12.9.86-ha770c72_2.conda sha256: e6257534c4b4b6b8a1192f84191c34906ab9968c92680fa09f639e7846a87304 md5: 79d280de61e18010df5997daea4743df diff --git a/cuda_core/pixi.toml b/cuda_core/pixi.toml index 8683992cad..312262e8bc 100644 --- a/cuda_core/pixi.toml +++ b/cuda_core/pixi.toml @@ -68,7 +68,7 @@ cu12 = { features = ["cu12", "test", "cython-tests"], solve-group = "cu12" } # TODO: check if these can be extracted from pyproject.toml [package] name = "cuda-core" -version = "0.5.0" +version = "0.5.1" [package.build] backend = { name = "pixi-build-python", version = "*" } From bd8c21401a24b390003cefd17f30c4b00d49e889 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 14 Jan 2026 13:28:19 -0500 Subject: [PATCH 8/8] ci: trigger build
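Notes on the series (illustrative sketches, not part of any patch):

Patch 1/8 deletes _validate_itemsize, whose rejected case relied on the
standard bit trick: for a positive integer n, n & (n - 1) clears the lowest
set bit, so the expression is zero exactly when n is a power of two. A
minimal sketch of the dropped check, with a hypothetical helper name:

    def is_power_of_two(n: int) -> bool:
        # Subtracting 1 flips the lowest set bit and every bit below it,
        # so the AND is zero only for 1, 2, 4, 8, ...
        return n > 0 and (n & (n - 1)) == 0

    # Packed struct dtypes commonly have non-power-of-two itemsizes, e.g.
    # np.dtype([("a", "int32"), ("b", "int8")]).itemsize == 5, which the
    # old check rejected and this series now accepts.
    assert is_power_of_two(4)
    assert not is_power_of_two(5)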
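Taken together, patches 1/8 through 3/8 make host-side structured arrays
viewable end to end: _StridedLayout accepts non-power-of-two itemsizes,
typestrs are resolved through numpy.dtype instead of a fixed table of
built-in numeric types, and the new from_array_interface classmethod wraps
any __array_interface__ v3 exporter. A minimal end-to-end sketch, using
only the API exercised by the new tests (it assumes a current CUDA context,
which the tests obtain via the init_cuda fixture):

    import numpy as np

    from cuda.core import Device
    from cuda.core.utils import StridedMemoryView

    # view_as_array_interface queries the current context for a device id,
    # so make one current first.
    Device().set_current()

    # itemsize 5 is not a power of two; patch 1/8 makes this layout legal.
    dtype = np.dtype([("a", "int32"), ("b", "int8")])
    x = np.zeros(4, dtype=dtype)

    smv = StridedMemoryView.from_array_interface(x)
    assert smv.shape == x.shape
    assert smv.ptr == x.ctypes.data
    assert smv.is_device_accessible is False
    assert smv.strides is None  # C-contiguous arrays export no strides
    # Structured dtypes round-trip only as void types, because typestrs
    # are lossy ("|V5" here); see the comment in test_struct_array.
    assert smv.dtype == np.dtype(f"V{dtype.itemsize}")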
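Patch 4/8 removes the per-call warnings.catch_warnings() blocks by
allocating with StridedMemoryView.__new__(cls) instead of calling cls():
tp.__new__ creates the instance without running __init__, so the
constructor's DeprecationWarning (which the old code suppressed) is never
raised and no warnings filter has to be installed on every call. An
illustrative reduction of the pattern, with hypothetical names:

    import warnings

    class Handle:
        def __init__(self):
            warnings.warn("construct via Handle.from_ptr()", DeprecationWarning)

        @classmethod
        def from_ptr(cls, ptr: int) -> "Handle":
            # __new__ allocates without invoking __init__: no warning is
            # emitted, so there is nothing to catch or filter per call.
            self = cls.__new__(cls)
            self.ptr = ptr
            return self

    h = Handle.from_ptr(0xDEAD)  # constructs without a DeprecationWarning
    assert h.ptr == 0xDEAD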