@@ -226,75 +226,108 @@ def test_buffer_copy_from():
226226def buffer_fill (dummy_mr : MemoryResource , device : Device , check = False ):
227227 stream = device .create_stream ()
228228
229- # Test width=1 ( byte fill)
229+ # Test 1- byte fill (int in [0, 256) )
230230 buffer1 = dummy_mr .allocate (size = 1024 )
231- buffer1 .fill (0x42 , width = 1 , stream = stream )
231+ buffer1 .fill (0x42 , stream = stream )
232232 device .sync ()
233233
234234 if check :
235235 ptr = ctypes .cast (buffer1 .handle , ctypes .POINTER (ctypes .c_byte ))
236236 for i in range (10 ):
237237 assert ptr [i ] == 0x42
238238
239- # Test error: invalid width
240- for bad_width in [w for w in range ( - 10 , 10 ) if w not in ( 1 , 2 , 4 ) ]:
241- with pytest .raises (ValueError , match = "width must be 1, 2, or 4" ):
242- buffer1 .fill (0x42 , width = bad_width , stream = stream )
239+ # Test error: int value out of range (OverflowError)
240+ for bad_value in [- 42 , - 1 , 256 , 1000 ]:
241+ with pytest .raises (OverflowError ):
242+ buffer1 .fill (bad_value , stream = stream )
243243
244- # Test error: value out of range for width=1
245- for bad_value in [- 42 , - 1 , 256 ]:
246- with pytest .raises (ValueError , match = "value must be in range \\ [0, 255\\ ]" ):
247- buffer1 .fill (bad_value , width = 1 , stream = stream )
248-
249- # Test error: buffer size not divisible by width
250- for bad_size in [1025 , 1027 , 1029 , 1031 ]: # Not divisible by 2
251- buffer_err = dummy_mr .allocate (size = 1025 )
252- with pytest .raises (ValueError , match = "must be divisible" ):
253- buffer_err .fill (0x1234 , width = 2 , stream = stream )
254- buffer_err .close ()
244+ # Test error: invalid type (not int and not buffer-protocol)
245+ with pytest .raises (TypeError , match = "must be an int or support the buffer protocol" ):
246+ buffer1 .fill ("invalid" , stream = stream )
255247
256248 buffer1 .close ()
257249
258- # Test width=2 (16-bit fill)
259- buffer2 = dummy_mr .allocate (size = 1024 ) # Divisible by 2
260- buffer2 .fill (0x1234 , width = 2 , stream = stream )
250+ # Test 2-byte fill via numpy uint16
251+ if np is not None :
252+ buffer2 = dummy_mr .allocate (size = 1024 ) # Divisible by 2
253+ buffer2 .fill (np .uint16 (0x1234 ), stream = stream )
254+ device .sync ()
255+
256+ if check :
257+ ptr = ctypes .cast (buffer2 .handle , ctypes .POINTER (ctypes .c_uint16 ))
258+ for i in range (5 ):
259+ assert ptr [i ] == 0x1234
260+
261+ buffer2 .close ()
262+
263+ # Test 2-byte fill via raw bytes
264+ buffer2b = dummy_mr .allocate (size = 1024 )
265+ buffer2b .fill (b"\x34 \x12 " , stream = stream ) # 0x1234 in little-endian
261266 device .sync ()
262267
263268 if check :
264- ptr = ctypes .cast (buffer2 .handle , ctypes .POINTER (ctypes .c_uint16 ))
269+ ptr = ctypes .cast (buffer2b .handle , ctypes .POINTER (ctypes .c_uint16 ))
265270 for i in range (5 ):
266271 assert ptr [i ] == 0x1234
267272
268- # Test error: value out of range for width=2
269- for bad_value in [- 42 , - 1 , 65536 , 65537 , 100000 ]:
270- with pytest .raises (ValueError , match = "value must be in range \\ [0, 65535\\ ]" ):
271- buffer2 .fill (bad_value , width = 2 , stream = stream )
273+ # Test error: buffer size not divisible by 2
274+ buffer_err = dummy_mr .allocate (size = 1025 )
275+ with pytest .raises (ValueError , match = "must be divisible by 2" ):
276+ buffer_err .fill (b"\x12 \x34 " , stream = stream )
277+ buffer_err .close ()
272278
273- buffer2 .close ()
279+ buffer2b .close ()
280+
281+ # Test 4-byte fill via numpy uint32
282+ if np is not None :
283+ buffer4 = dummy_mr .allocate (size = 1024 ) # Divisible by 4
284+ buffer4 .fill (np .uint32 (0xDEADBEEF ), stream = stream )
285+ device .sync ()
286+
287+ if check :
288+ ptr = ctypes .cast (buffer4 .handle , ctypes .POINTER (ctypes .c_uint32 ))
289+ for i in range (5 ):
290+ assert ptr [i ] == 0xDEADBEEF
274291
275- # Test width=4 (32-bit fill)
276- buffer4 = dummy_mr .allocate (size = 1024 ) # Divisible by 4
277- buffer4 .fill (0xDEADBEEF , width = 4 , stream = stream )
292+ buffer4 .close ()
293+
294+ # Test 4-byte fill via raw bytes
295+ buffer4b = dummy_mr .allocate (size = 1024 )
296+ buffer4b .fill (b"\xef \xbe \xad \xde " , stream = stream ) # 0xDEADBEEF in little-endian
278297 device .sync ()
279298
280299 if check :
281- ptr = ctypes .cast (buffer4 .handle , ctypes .POINTER (ctypes .c_uint32 ))
300+ ptr = ctypes .cast (buffer4b .handle , ctypes .POINTER (ctypes .c_uint32 ))
282301 for i in range (5 ):
283302 assert ptr [i ] == 0xDEADBEEF
284303
285- # Test error: value out of range for width=4
286- for bad_value in [- 42 , - 1 , 4294967296 , 4294967297 , 5000000000 ]:
287- with pytest .raises (ValueError , match = "value must be in range \\ [0, 4294967295\\ ]" ):
288- buffer4 .fill (bad_value , width = 4 , stream = stream )
289-
290- # Test error: buffer size not divisible by width
291- for bad_size in [1025 , 1026 , 1027 , 1029 , 1030 , 1031 ]: # Not divisible by 4
304+ # Test error: buffer size not divisible by 4
305+ for bad_size in [1025 , 1026 , 1027 ]:
292306 buffer_err2 = dummy_mr .allocate (size = bad_size )
293- with pytest .raises (ValueError , match = "must be divisible" ):
294- buffer_err2 .fill (0xDEADBEEF , width = 4 , stream = stream )
307+ with pytest .raises (ValueError , match = "must be divisible by 4 " ):
308+ buffer_err2 .fill (b" \xde \xad \xbe \xef " , stream = stream )
295309 buffer_err2 .close ()
296310
297- buffer4 .close ()
311+ buffer4b .close ()
312+
313+ # Test error: invalid byte length (not 1, 2, or 4)
314+ buffer_err3 = dummy_mr .allocate (size = 1024 )
315+ with pytest .raises (ValueError , match = "value must be 1, 2, or 4 bytes, got 3" ):
316+ buffer_err3 .fill (b"\x01 \x02 \x03 " , stream = stream )
317+ buffer_err3 .close ()
318+
319+ # Test float32 fill via numpy
320+ if np is not None :
321+ buffer_float = dummy_mr .allocate (size = 1024 )
322+ buffer_float .fill (np .float32 (1.0 ), stream = stream )
323+ device .sync ()
324+
325+ if check :
326+ ptr = ctypes .cast (buffer_float .handle , ctypes .POINTER (ctypes .c_float ))
327+ for i in range (5 ):
328+ assert ptr [i ] == 1.0
329+
330+ buffer_float .close ()
298331
299332
300333def test_buffer_fill ():
0 commit comments