diff --git a/cuda_core/cuda/core/_context.pyx b/cuda_core/cuda/core/_context.pyx
index cff79f62ad..61fd0f79d4 100644
--- a/cuda_core/cuda/core/_context.pyx
+++ b/cuda_core/cuda/core/_context.pyx
@@ -46,7 +46,7 @@ cdef class Context:
         return as_intptr(self._h_context) == as_intptr(_other._h_context)
 
     def __hash__(self) -> int:
-        return hash((type(self), as_intptr(self._h_context)))
+        return hash(as_intptr(self._h_context))
 
 
 @dataclass
diff --git a/cuda_core/cuda/core/_event.pyx b/cuda_core/cuda/core/_event.pyx
index 951d7911b7..42916b257b 100644
--- a/cuda_core/cuda/core/_event.pyx
+++ b/cuda_core/cuda/core/_event.pyx
@@ -169,7 +169,7 @@ cdef class Event:
             raise RuntimeError(explanation)
 
     def __hash__(self) -> int:
-        return hash((type(self), as_intptr(self._h_event)))
+        return hash(as_intptr(self._h_event))
 
     def __eq__(self, other) -> bool:
         # Note: using isinstance because `Event` can be subclassed.
diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx
index 032c40bd78..8261b69df2 100644
--- a/cuda_core/cuda/core/_launch_config.pyx
+++ b/cuda_core/cuda/core/_launch_config.pyx
@@ -21,6 +21,9 @@ cdef bint _inited = False
 cdef bint _use_ex = False
 cdef object _lock = threading.Lock()
 
+# Attribute names for identity comparison and representation
+_LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch')
+
 cdef int _lazy_init() except?-1:
     global _inited, _use_ex
 
@@ -131,11 +134,21 @@
         if self.cooperative_launch and not Device().properties.cooperative_launch:
             raise CUDAError("cooperative kernels are not supported on this device")
 
+    def _identity(self):
+        return tuple(getattr(self, attr) for attr in _LAUNCH_CONFIG_ATTRS)
+
     def __repr__(self):
         """Return string representation of LaunchConfig."""
-        return (f"LaunchConfig(grid={self.grid}, cluster={self.cluster}, "
-                f"block={self.block}, shmem_size={self.shmem_size}, "
-                f"cooperative_launch={self.cooperative_launch})")
+        parts = ', '.join(f'{attr}={getattr(self, attr)!r}' for attr in _LAUNCH_CONFIG_ATTRS)
+        return f"LaunchConfig({parts})"
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, LaunchConfig):
+            return NotImplemented
+        return self._identity() == other._identity()
+
+    def __hash__(self) -> int:
+        return hash(self._identity())
 
     cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
         _lazy_init()
diff --git a/cuda_core/cuda/core/_memory/_buffer.pyx b/cuda_core/cuda/core/_memory/_buffer.pyx
index 1df3841b4a..57ec315201 100644
--- a/cuda_core/cuda/core/_memory/_buffer.pyx
+++ b/cuda_core/cuda/core/_memory/_buffer.pyx
@@ -324,6 +324,16 @@ cdef class Buffer:
         # that expect a raw pointer value
         return as_intptr(self._h_ptr)
 
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, Buffer):
+            return NotImplemented
+        cdef Buffer other_buf = other
+        return (as_intptr(self._h_ptr) == as_intptr(other_buf._h_ptr) and
+                self._size == other_buf._size)
+
+    def __hash__(self) -> int:
+        return hash((as_intptr(self._h_ptr), self._size))
+
     @property
     def is_device_accessible(self) -> bool:
         """Return True if this buffer can be accessed by the GPU, otherwise False."""
diff --git a/cuda_core/cuda/core/_module.py b/cuda_core/cuda/core/_module.py
index 6abb7dfd31..e0648e151c 100644
--- a/cuda_core/cuda/core/_module.py
+++ b/cuda_core/cuda/core/_module.py
@@ -528,6 +528,14 @@ def from_handle(handle: int, mod: ObjectCode = None) -> Kernel:
         return Kernel._from_obj(kernel_obj, mod)
 
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, Kernel):
+            return NotImplemented
+        return int(self._handle) == int(other._handle)
+
+    def __hash__(self) -> int:
+        return hash(int(self._handle))
+
 
 CodeTypeT = bytes | bytearray | str
 
 
@@ -757,3 +765,13 @@ def handle(self):
         handle, call ``int(ObjectCode.handle)``.
         """
         return self._handle
+
+    def __eq__(self, other) -> bool:
+        if not isinstance(other, ObjectCode):
+            return NotImplemented
+        # Trigger lazy load for both objects to compare handles
+        return int(self.handle) == int(other.handle)
+
+    def __hash__(self) -> int:
+        # Trigger lazy load to get the handle
+        return hash(int(self.handle))
diff --git a/cuda_core/tests/test_hashable.py b/cuda_core/tests/test_hashable.py
index feeae9b07b..89ebb33394 100644
--- a/cuda_core/tests/test_hashable.py
+++ b/cuda_core/tests/test_hashable.py
@@ -12,24 +12,109 @@ 5. Hash/equality contract compliance (if a == b, then hash(a) must equal hash(b))
 """
 
-from cuda.core import Device
+import pytest
+from cuda.core import Device, LaunchConfig, Program
 from cuda.core._stream import Stream, StreamOptions
 
 
+# ============================================================================
+# Fixtures for parameterized tests
+# ============================================================================
+
+
+@pytest.fixture
+def sample_device(init_cuda):
+    return Device()
+
+
+@pytest.fixture
+def sample_stream(sample_device):
+    return sample_device.create_stream()
+
+
+@pytest.fixture
+def sample_event(sample_device):
+    return sample_device.create_event()
+
+
+@pytest.fixture
+def sample_context(sample_device):
+    return sample_device.context
+
+
+@pytest.fixture
+def sample_buffer(sample_device):
+    return sample_device.allocate(1024)
+
+
+@pytest.fixture
+def sample_launch_config():
+    return LaunchConfig(grid=(1,), block=(1,))
+
+
+@pytest.fixture
+def sample_object_code(init_cuda):
+    prog = Program('extern "C" __global__ void test_kernel() {}', "c++")
+    return prog.compile("ptx")
+
+
+@pytest.fixture
+def sample_kernel(sample_object_code):
+    return sample_object_code.get_kernel("test_kernel")
+
+
+# All hashable classes
+HASHABLE = [
+    "sample_device",
+    "sample_stream",
+    "sample_event",
+    "sample_context",
+    "sample_buffer",
+    "sample_launch_config",
+    "sample_object_code",
+    "sample_kernel",
+]
+
+
+# ============================================================================
+# Parameterized Hash Tests
+# ============================================================================
+
+
+@pytest.mark.parametrize("fixture_name", HASHABLE)
+def test_hash_consistency(fixture_name, request):
+    """Hash of same object is consistent across calls."""
+    obj = request.getfixturevalue(fixture_name)
+    assert hash(obj) == hash(obj)
+
+
+@pytest.mark.parametrize("fixture_name", HASHABLE)
+def test_set_membership(fixture_name, request):
+    """Objects work correctly in sets."""
+    obj = request.getfixturevalue(fixture_name)
+    s = {obj}
+    assert obj in s
+    assert len(s) == 1
+
+
+@pytest.mark.parametrize("fixture_name", HASHABLE)
+def test_dict_key(fixture_name, request):
+    """Objects work correctly as dict keys."""
+    obj = request.getfixturevalue(fixture_name)
+    d = {obj: "value"}
+    assert d[obj] == "value"
+
+
 # ============================================================================
 # Integration Tests
 # ============================================================================
 
 
-def test_hash_type_disambiguation_and_mixed_dict(init_cuda):
-    """Test that hash salt (type(self)) prevents collisions between different types
-    and that different object types can coexist in dictionaries.
+def test_mixed_type_dict(init_cuda):
+    """Test that different object types can coexist in dictionaries.
 
-    This test validates that:
-    1. Including type(self) in the hash calculation ensures different types with
-       potentially similar underlying values (like monotonically increasing handles
-       or IDs) produce different hashes and don't collide.
-    2. Different object types can be used together in the same dictionary without
-       conflicts.
+    Since each CUDA handle type has unique values within its type (handles are
+    memory addresses or unique identifiers), hash collisions between different
+    types are unlikely in practice.
     """
     device = Device(0)
     device.set_current()