diff --git a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx index 09179ef288..bd1f364732 100644 --- a/cuda_core/cuda/core/experimental/_memory/_buffer.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_buffer.pyx @@ -5,7 +5,8 @@ from __future__ import annotations cimport cython -from libc.stdint cimport uintptr_t +from libc.stdint cimport uint8_t, uint16_t, uint32_t, uintptr_t +from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyBUF_SIMPLE from cuda.bindings cimport cydriver from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource @@ -232,65 +233,27 @@ cdef class Buffer: """ cdef Stream s_stream = Stream_accept(stream) - cdef unsigned char c_value8 - cdef unsigned short c_value16 - cdef unsigned int c_value32 - cdef size_t N - cdef size_t width - cdef unsigned int int_value - - # Get fill pattern from value + + # Handle int case: 1-byte fill with automatic overflow checking. if isinstance(value, int): - # We define the int input to mean a 1-byte pattern. - # Match int.to_bytes(1, "little") behavior: raise OverflowError if not in [0, 256). - if value < 0 or value >= 256: - raise OverflowError("int value must be in range [0, 256)") - width = 1 - int_value = value - else: - try: - mv = memoryview(value) - except TypeError: - raise TypeError( - f"value must be an int or support the buffer protocol, got {type(value).__name__}" - ) from None - width = mv.nbytes - - # Validate width early to avoid copying/processing large invalid inputs. - if width not in (1, 2, 4): - raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") - - # Convert to a 1-D view of bytes. - # - # Note: NumPy scalar memoryviews are 0-D, and int.from_bytes(mv, ...) errors with - # "0-dim memory has no length". Casting to 'B' gives us a byte-addressable view. - try: - int_value = int.from_bytes(mv.cast("B"), "little") - except TypeError: - int_value = int.from_bytes(mv.tobytes(), "little") - - # Validate buffer size modulus. - cdef size_t buffer_size = self._size - if buffer_size % width != 0: - raise ValueError(f"buffer size ({buffer_size}) must be divisible by {width}") - - # Perform fill based on width - cdef cydriver.CUstream s = s_stream._handle - if width == 1: - c_value8 = int_value - N = buffer_size - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD8Async(self._ptr, c_value8, N, s)) - elif width == 2: - c_value16 = int_value - N = buffer_size // 2 - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD16Async(self._ptr, c_value16, N, s)) - else: # width == 4 - c_value32 = int_value - N = buffer_size // 4 - with nogil: - HANDLE_RETURN(cydriver.cuMemsetD32Async(self._ptr, c_value32, N, s)) + Buffer_fill_uint8(self, value, s_stream._handle) + return + + # Handle bytes case: direct pointer access without intermediate objects. + if isinstance(value, bytes): + Buffer_fill_from_ptr(self, value, len(value), s_stream._handle) + return + + # General buffer protocol path using C buffer API. + cdef Py_buffer buf + if PyObject_GetBuffer(value, &buf, PyBUF_SIMPLE) != 0: + raise TypeError( + f"value must be an int or support the buffer protocol, got {type(value).__name__}" + ) + try: + Buffer_fill_from_ptr(self, buf.buf, buf.len, s_stream._handle) + finally: + PyBuffer_Release(&buf) def __dlpack__( self, @@ -419,6 +382,36 @@ cdef inline void Buffer_close(Buffer self, stream): self._alloc_stream = None +cdef inline void Buffer_fill_uint8(Buffer self, uint8_t value, cydriver.CUstream s): + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD8Async(self._ptr, value, self._size, s)) + + +cdef inline void Buffer_fill_from_ptr( + Buffer self, const char* ptr, size_t width, cydriver.CUstream s +) except *: + cdef size_t buffer_size = self._size + + if width == 1: + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD8Async( + self._ptr, (ptr)[0], buffer_size, s)) + elif width == 2: + if buffer_size & 0x1: + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 2") + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD16Async( + self._ptr, (ptr)[0], buffer_size // 2, s)) + elif width == 4: + if buffer_size & 0x3: + raise ValueError(f"buffer size ({buffer_size}) must be divisible by 4") + with nogil: + HANDLE_RETURN(cydriver.cuMemsetD32Async( + self._ptr, (ptr)[0], buffer_size // 4, s)) + else: + raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}") + + cdef Buffer_init_mem_attrs(Buffer self): if not self._mem_attrs_inited: query_memory_attrs(self._mem_attrs, self._ptr)