diff --git a/cuda_core/cuda/core/_context.pxd b/cuda_core/cuda/core/_context.pxd index dc853fc75d..9e1a460f50 100644 --- a/cuda_core/cuda/core/_context.pxd +++ b/cuda_core/cuda/core/_context.pxd @@ -14,6 +14,7 @@ cdef class Context: cdef: ContextHandle _h_context int _device_id + object __weakref__ @staticmethod cdef Context _from_handle(type cls, ContextHandle h_context, int device_id) diff --git a/cuda_core/cuda/core/_device.pyx b/cuda_core/cuda/core/_device.pyx index 672a9b0926..9d143679f8 100644 --- a/cuda_core/cuda/core/_device.pyx +++ b/cuda_core/cuda/core/_device.pyx @@ -955,7 +955,7 @@ class Device: Default value of `None` return the currently used device. """ - __slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context") + __slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context", "__weakref__") def __new__(cls, device_id: Device | int | None = None): if isinstance(device_id, Device): diff --git a/cuda_core/cuda/core/_event.pxd b/cuda_core/cuda/core/_event.pxd index e0e14e009f..c393b29ebf 100644 --- a/cuda_core/cuda/core/_event.pxd +++ b/cuda_core/cuda/core/_event.pxd @@ -16,6 +16,7 @@ cdef class Event: bint _ipc_enabled object _ipc_descriptor int _device_id + object __weakref__ @staticmethod cdef Event _init(type cls, int device_id, ContextHandle h_context, options, bint is_free) diff --git a/cuda_core/cuda/core/_launch_config.pxd b/cuda_core/cuda/core/_launch_config.pxd index eeb8ce41e7..909c236309 100644 --- a/cuda_core/cuda/core/_launch_config.pxd +++ b/cuda_core/cuda/core/_launch_config.pxd @@ -17,6 +17,7 @@ cdef class LaunchConfig: public bint cooperative_launch vector[cydriver.CUlaunchAttribute] _attrs + object __weakref__ cdef cydriver.CUlaunchConfig _to_native_launch_config(self) diff --git a/cuda_core/cuda/core/_memory/_buffer.pxd b/cuda_core/cuda/core/_memory/_buffer.pxd index 4238bd8d82..91c0cfe24a 100644 --- a/cuda_core/cuda/core/_memory/_buffer.pxd +++ b/cuda_core/cuda/core/_memory/_buffer.pxd @@ -23,6 +23,7 @@ cdef class Buffer: object _owner _MemAttrs _mem_attrs bint _mem_attrs_inited + object __weakref__ cdef class MemoryResource: diff --git a/cuda_core/cuda/core/_module.py b/cuda_core/cuda/core/_module.py index 6abb7dfd31..4fa3380cd5 100644 --- a/cuda_core/cuda/core/_module.py +++ b/cuda_core/cuda/core/_module.py @@ -546,7 +546,7 @@ class ObjectCode: :class:`~cuda.core.Program` """ - __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name") + __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name", "__weakref__") _supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library") def __new__(self, *args, **kwargs): diff --git a/cuda_core/cuda/core/_stream.pxd b/cuda_core/cuda/core/_stream.pxd index 69bd5821ad..c47ff1ea28 100644 --- a/cuda_core/cuda/core/_stream.pxd +++ b/cuda_core/cuda/core/_stream.pxd @@ -13,6 +13,7 @@ cdef class Stream: int _device_id int _nonblocking int _priority + object __weakref__ @staticmethod cdef Stream _from_handle(type cls, StreamHandle h_stream) diff --git a/cuda_core/tests/test_weakref.py b/cuda_core/tests/test_weakref.py new file mode 100644 index 0000000000..caf4239e67 --- /dev/null +++ b/cuda_core/tests/test_weakref.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import weakref + +import pytest +from cuda.core import Device + + +@pytest.fixture(scope="module") +def device(): + dev = Device() + dev.set_current() + return dev + + +@pytest.fixture +def stream(device): + return device.create_stream() + + +@pytest.fixture +def event(device): + return device.create_event() + + +@pytest.fixture +def context(device): + return device.context + + +@pytest.fixture +def buffer(device): + return device.allocate(1024) + + +@pytest.fixture +def launch_config(): + from cuda.core import LaunchConfig + + return LaunchConfig(grid=(1,), block=(1,)) + + +@pytest.fixture +def object_code(): + from cuda.core import Program + + prog = Program('extern "C" __global__ void test_kernel() {}', "c++") + return prog.compile("ptx") + + +@pytest.fixture +def kernel(object_code): + return object_code.get_kernel("test_kernel") + + +WEAK_REFERENCEABLE = [ + "device", + "stream", + "event", + "context", + "buffer", + "launch_config", + "object_code", + "kernel", +] + + +@pytest.mark.parametrize("fixture_name", WEAK_REFERENCEABLE) +def test_weakref(fixture_name, request): + """Core API classes should be weak-referenceable.""" + obj = request.getfixturevalue(fixture_name) + ref = weakref.ref(obj) + assert ref() is obj