From 545b4c77869d5e0aa5a36b1c3f6c8c053865da64 Mon Sep 17 00:00:00 2001
From: Andy Jost
Date: Sun, 25 Jan 2026 09:37:16 -0800
Subject: [PATCH] Add __eq__ and __hash__ to Buffer, LaunchConfig, Kernel,
 ObjectCode

Make these classes hashable and comparable:

- Buffer: identity based on (type, ptr, size)
- LaunchConfig: uses _LAUNCH_CONFIG_ATTRS tuple for forward-compatible
  identity; also updates __repr__ to use the same attribute list
- Kernel: identity based on (type, handle)
- ObjectCode: identity based on (type, handle), triggers lazy load

Stream, Event, Context, Device already had __eq__/__hash__.
---
 cuda_core/cuda/core/_launch_config.pyx  | 19 ++++++++++++++++---
 cuda_core/cuda/core/_memory/_buffer.pyx | 10 ++++++++++
 cuda_core/cuda/core/_module.py          | 18 ++++++++++++++++++
 3 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx
index 032c40bd78..f59fea0716 100644
--- a/cuda_core/cuda/core/_launch_config.pyx
+++ b/cuda_core/cuda/core/_launch_config.pyx
@@ -21,6 +21,9 @@ cdef bint _inited = False
 cdef bint _use_ex = False
 cdef object _lock = threading.Lock()
 
+# Attribute names for identity comparison and representation
+_LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch')
+
 
 cdef int _lazy_init() except?-1:
     global _inited, _use_ex
@@ -131,11 +134,21 @@ cdef class LaunchConfig:
         if self.cooperative_launch and not Device().properties.cooperative_launch:
             raise CUDAError("cooperative kernels are not supported on this device")
 
+    def _identity(self):
+        return tuple(getattr(self, attr) for attr in _LAUNCH_CONFIG_ATTRS)
+
     def __repr__(self):
         """Return string representation of LaunchConfig."""
-        return (f"LaunchConfig(grid={self.grid}, cluster={self.cluster}, "
-                f"block={self.block}, shmem_size={self.shmem_size}, "
-                f"cooperative_launch={self.cooperative_launch})")
+        parts = ', '.join(f'{attr}={getattr(self, attr)!r}' for attr in _LAUNCH_CONFIG_ATTRS)
+        return f"LaunchConfig({parts})"
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, LaunchConfig):
+            return NotImplemented
+        return self._identity() == (other)._identity()
+
+    def __hash__(self) -> int:
+        return hash((type(self),) + self._identity())
 
     cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
         _lazy_init()
diff --git a/cuda_core/cuda/core/_memory/_buffer.pyx b/cuda_core/cuda/core/_memory/_buffer.pyx
index 1df3841b4a..f31e6c99a3 100644
--- a/cuda_core/cuda/core/_memory/_buffer.pyx
+++ b/cuda_core/cuda/core/_memory/_buffer.pyx
@@ -324,6 +324,16 @@ cdef class Buffer:
         # that expect a raw pointer value
         return as_intptr(self._h_ptr)
 
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, Buffer):
+            return NotImplemented
+        cdef Buffer other_buf = other
+        return (as_intptr(self._h_ptr) == as_intptr(other_buf._h_ptr) and
+                self._size == other_buf._size)
+
+    def __hash__(self) -> int:
+        return hash((type(self), as_intptr(self._h_ptr), self._size))
+
     @property
     def is_device_accessible(self) -> bool:
         """Return True if this buffer can be accessed by the GPU, otherwise False."""
diff --git a/cuda_core/cuda/core/_module.py b/cuda_core/cuda/core/_module.py
index 6abb7dfd31..e9bbcc3d06 100644
--- a/cuda_core/cuda/core/_module.py
+++ b/cuda_core/cuda/core/_module.py
@@ -528,6 +528,14 @@ def from_handle(handle: int, mod: ObjectCode = None) -> Kernel:
         return Kernel._from_obj(kernel_obj, mod)
 
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, Kernel):
+            return NotImplemented
+        return int(self._handle) == int(other._handle)
+
+    def __hash__(self) -> int:
+        return hash((type(self), int(self._handle)))
+
 
 CodeTypeT = bytes | bytearray | str
 
@@ -757,3 +765,13 @@ def handle(self):
         handle, call ``int(ObjectCode.handle)``.
         """
         return self._handle
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, ObjectCode):
+            return NotImplemented
+        # Trigger lazy load for both objects to compare handles
+        return int(self.handle) == int(other.handle)
+
+    def __hash__(self) -> int:
+        # Trigger lazy load to get the handle
+        return hash((type(self), int(self.handle)))