Skip to content

Commit 030e423

Browse files
committed
Implement Buffer.fill() redesign
Simplify the API by removing the explicit width parameter and inferring the fill width from the value itself. fill() now accepts either an int in [0, 256), which produces a 1-byte fill, or a collections.abc.Buffer object whose length is 1, 2, or 4 bytes, which produces a fill of that width.
1 parent 8e63850 commit 030e423

File tree

3 files changed

+116
-60
lines changed

3 files changed

+116
-60
lines changed

cuda_core/cuda/core/experimental/_memory/_buffer.pyx

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@ from cuda.core.experimental._stream cimport Stream_accept, Stream
1515
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN
1616

1717
import abc
18+
import sys
1819
from typing import TypeVar, Union
1920

21+
if sys.version_info >= (3, 12):
22+
from collections.abc import Buffer as BufferProtocol
23+
else:
24+
BufferProtocol = object
25+
2026
from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule
2127
from cuda.core.experimental._utils.cuda_utils import driver
2228
from cuda.core.experimental._device import Device
@@ -203,58 +209,75 @@ cdef class Buffer:
203209
s
204210
))
205211

206-
def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder):
207-
"""Fill this buffer with a value pattern asynchronously on the given stream.
212+
def fill(self, value: int | BufferProtocol, *, stream: Stream | GraphBuilder):
213+
"""Fill this buffer with a repeating byte pattern.
208214
209215
Parameters
210216
----------
211-
value : int
212-
Integer value to fill the buffer with
213-
width : int
214-
Width in bytes for each element (must be 1, 2, or 4)
217+
value : int | :obj:`collections.abc.Buffer`
218+
- int: Must be in range [0, 256). Converted to 1 byte.
219+
- :obj:`collections.abc.Buffer`: Must be 1, 2, or 4 bytes.
215220
stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
216-
Keyword argument specifying the stream for the asynchronous fill
221+
Stream for the asynchronous fill operation.
217222
218223
Raises
219224
------
225+
TypeError
226+
If value is not an int and does not support the buffer protocol.
220227
ValueError
221-
If width is not 1, 2, or 4, if value is out of range for the width,
222-
or if buffer size is not divisible by width
228+
If value byte length is not 1, 2, or 4.
229+
If buffer size is not divisible by value byte length.
230+
OverflowError
231+
If int value is outside [0, 256).
223232
224233
"""
225234
cdef Stream s_stream = Stream_accept(stream)
226235
cdef unsigned char c_value8
227236
cdef unsigned short c_value16
228237
cdef unsigned int c_value32
229238
cdef size_t N
239+
cdef size_t width
240+
cdef bytes pattern
241+
242+
# Get fill pattern from value
243+
if isinstance(value, int):
244+
# int.to_bytes raises OverflowError if not in [0, 256)
245+
pattern = value.to_bytes(1, "little")
246+
else:
247+
try:
248+
mv = memoryview(value)
249+
except TypeError:
250+
raise TypeError(
251+
f"value must be an int or support the buffer protocol, got {type(value).__name__}"
252+
) from None
253+
pattern = mv.tobytes()
254+
255+
width = len(pattern)
230256

231257
# Validate width
232258
if width not in (1, 2, 4):
233-
raise ValueError(f"width must be 1, 2, or 4, got {width}")
259+
raise ValueError(f"value must be 1, 2, or 4 bytes, got {width}")
234260

235261
# Validate buffer size modulus.
236262
cdef size_t buffer_size = self._size
237263
if buffer_size % width != 0:
238-
raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})")
239-
240-
# Map width (bytes) to bitwidth and validate value
241-
cdef int bitwidth = width * 8
242-
_validate_value_against_bitwidth(bitwidth, value, is_signed=False)
264+
raise ValueError(f"buffer size ({buffer_size}) must be divisible by {width}")
243265

244-
# Validate value fits in width and perform fill
266+
# Perform fill based on width
245267
cdef cydriver.CUstream s = s_stream._handle
268+
int_value = int.from_bytes(pattern, "little")
246269
if width == 1:
247-
c_value8 = <unsigned char>value
270+
c_value8 = int_value
248271
N = buffer_size
249272
with nogil:
250273
HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
251274
elif width == 2:
252-
c_value16 = <unsigned short>value
275+
c_value16 = int_value
253276
N = buffer_size // 2
254277
with nogil:
255278
HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
256279
else: # width == 4
257-
c_value32 = <unsigned int>value
280+
c_value32 = int_value
258281
N = buffer_size // 4
259282
with nogil:
260283
HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))

cuda_core/tests/test_graph_mem.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def apply_kernels(mr, stream, out):
112112
# Fills out with 3
113113
def apply_kernels(mr, stream, out):
114114
buffer = mr.allocate(NBYTES, stream=stream)
115-
buffer.fill(3, width=1, stream=stream)
115+
buffer.fill(3, stream=stream)
116116
out.copy_from(buffer, stream=stream)
117117
buffer.close()
118118

cuda_core/tests/test_memory.py

Lines changed: 73 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -226,75 +226,108 @@ def test_buffer_copy_from():
226226
def buffer_fill(dummy_mr: MemoryResource, device: Device, check=False):
227227
stream = device.create_stream()
228228

229-
# Test width=1 (byte fill)
229+
# Test 1-byte fill (int in [0, 256))
230230
buffer1 = dummy_mr.allocate(size=1024)
231-
buffer1.fill(0x42, width=1, stream=stream)
231+
buffer1.fill(0x42, stream=stream)
232232
device.sync()
233233

234234
if check:
235235
ptr = ctypes.cast(buffer1.handle, ctypes.POINTER(ctypes.c_byte))
236236
for i in range(10):
237237
assert ptr[i] == 0x42
238238

239-
# Test error: invalid width
240-
for bad_width in [w for w in range(-10, 10) if w not in (1, 2, 4)]:
241-
with pytest.raises(ValueError, match="width must be 1, 2, or 4"):
242-
buffer1.fill(0x42, width=bad_width, stream=stream)
239+
# Test error: int value out of range (OverflowError)
240+
for bad_value in [-42, -1, 256, 1000]:
241+
with pytest.raises(OverflowError):
242+
buffer1.fill(bad_value, stream=stream)
243243

244-
# Test error: value out of range for width=1
245-
for bad_value in [-42, -1, 256]:
246-
with pytest.raises(ValueError, match="value must be in range \\[0, 255\\]"):
247-
buffer1.fill(bad_value, width=1, stream=stream)
248-
249-
# Test error: buffer size not divisible by width
250-
for bad_size in [1025, 1027, 1029, 1031]: # Not divisible by 2
251-
buffer_err = dummy_mr.allocate(size=1025)
252-
with pytest.raises(ValueError, match="must be divisible"):
253-
buffer_err.fill(0x1234, width=2, stream=stream)
254-
buffer_err.close()
244+
# Test error: invalid type (not int and not buffer-protocol)
245+
with pytest.raises(TypeError, match="must be an int or support the buffer protocol"):
246+
buffer1.fill("invalid", stream=stream)
255247

256248
buffer1.close()
257249

258-
# Test width=2 (16-bit fill)
259-
buffer2 = dummy_mr.allocate(size=1024) # Divisible by 2
260-
buffer2.fill(0x1234, width=2, stream=stream)
250+
# Test 2-byte fill via numpy uint16
251+
if np is not None:
252+
buffer2 = dummy_mr.allocate(size=1024) # Divisible by 2
253+
buffer2.fill(np.uint16(0x1234), stream=stream)
254+
device.sync()
255+
256+
if check:
257+
ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16))
258+
for i in range(5):
259+
assert ptr[i] == 0x1234
260+
261+
buffer2.close()
262+
263+
# Test 2-byte fill via raw bytes
264+
buffer2b = dummy_mr.allocate(size=1024)
265+
buffer2b.fill(b"\x34\x12", stream=stream) # 0x1234 in little-endian
261266
device.sync()
262267

263268
if check:
264-
ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16))
269+
ptr = ctypes.cast(buffer2b.handle, ctypes.POINTER(ctypes.c_uint16))
265270
for i in range(5):
266271
assert ptr[i] == 0x1234
267272

268-
# Test error: value out of range for width=2
269-
for bad_value in [-42, -1, 65536, 65537, 100000]:
270-
with pytest.raises(ValueError, match="value must be in range \\[0, 65535\\]"):
271-
buffer2.fill(bad_value, width=2, stream=stream)
273+
# Test error: buffer size not divisible by 2
274+
buffer_err = dummy_mr.allocate(size=1025)
275+
with pytest.raises(ValueError, match="must be divisible by 2"):
276+
buffer_err.fill(b"\x12\x34", stream=stream)
277+
buffer_err.close()
272278

273-
buffer2.close()
279+
buffer2b.close()
280+
281+
# Test 4-byte fill via numpy uint32
282+
if np is not None:
283+
buffer4 = dummy_mr.allocate(size=1024) # Divisible by 4
284+
buffer4.fill(np.uint32(0xDEADBEEF), stream=stream)
285+
device.sync()
286+
287+
if check:
288+
ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32))
289+
for i in range(5):
290+
assert ptr[i] == 0xDEADBEEF
274291

275-
# Test width=4 (32-bit fill)
276-
buffer4 = dummy_mr.allocate(size=1024) # Divisible by 4
277-
buffer4.fill(0xDEADBEEF, width=4, stream=stream)
292+
buffer4.close()
293+
294+
# Test 4-byte fill via raw bytes
295+
buffer4b = dummy_mr.allocate(size=1024)
296+
buffer4b.fill(b"\xef\xbe\xad\xde", stream=stream) # 0xDEADBEEF in little-endian
278297
device.sync()
279298

280299
if check:
281-
ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32))
300+
ptr = ctypes.cast(buffer4b.handle, ctypes.POINTER(ctypes.c_uint32))
282301
for i in range(5):
283302
assert ptr[i] == 0xDEADBEEF
284303

285-
# Test error: value out of range for width=4
286-
for bad_value in [-42, -1, 4294967296, 4294967297, 5000000000]:
287-
with pytest.raises(ValueError, match="value must be in range \\[0, 4294967295\\]"):
288-
buffer4.fill(bad_value, width=4, stream=stream)
289-
290-
# Test error: buffer size not divisible by width
291-
for bad_size in [1025, 1026, 1027, 1029, 1030, 1031]: # Not divisible by 4
304+
# Test error: buffer size not divisible by 4
305+
for bad_size in [1025, 1026, 1027]:
292306
buffer_err2 = dummy_mr.allocate(size=bad_size)
293-
with pytest.raises(ValueError, match="must be divisible"):
294-
buffer_err2.fill(0xDEADBEEF, width=4, stream=stream)
307+
with pytest.raises(ValueError, match="must be divisible by 4"):
308+
buffer_err2.fill(b"\xde\xad\xbe\xef", stream=stream)
295309
buffer_err2.close()
296310

297-
buffer4.close()
311+
buffer4b.close()
312+
313+
# Test error: invalid byte length (not 1, 2, or 4)
314+
buffer_err3 = dummy_mr.allocate(size=1024)
315+
with pytest.raises(ValueError, match="value must be 1, 2, or 4 bytes, got 3"):
316+
buffer_err3.fill(b"\x01\x02\x03", stream=stream)
317+
buffer_err3.close()
318+
319+
# Test float32 fill via numpy
320+
if np is not None:
321+
buffer_float = dummy_mr.allocate(size=1024)
322+
buffer_float.fill(np.float32(1.0), stream=stream)
323+
device.sync()
324+
325+
if check:
326+
ptr = ctypes.cast(buffer_float.handle, ctypes.POINTER(ctypes.c_float))
327+
for i in range(5):
328+
assert ptr[i] == 1.0
329+
330+
buffer_float.close()
298331

299332

300333
def test_buffer_fill():

0 commit comments

Comments
 (0)