diff --git a/src/packaging/version.py b/src/packaging/version.py index a11d4639..936dc7e6 100644 --- a/src/packaging/version.py +++ b/src/packaging/version.py @@ -548,6 +548,175 @@ def _key(self) -> CmpKey: ) return self._key_cache + # __hash__ must be defined when __eq__ is overridden, + # otherwise Python sets __hash__ to None. + def __hash__(self) -> int: + return hash(self._key) + + # Override comparison methods to use direct _key_cache access + # This is faster than property access, especially before Python 3.12 + def __lt__(self, other: _BaseVersion) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache < other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__lt__(other) + + def __le__(self, other: _BaseVersion) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache <= other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__le__(other) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache == other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__eq__(other) + + def __ge__(self, other: _BaseVersion) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache >= other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__ge__(other) + + def __gt__(self, other: _BaseVersion) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache > other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__gt__(other) + + def __ne__(self, other: object) -> bool: + if isinstance(other, Version): + if self._key_cache is None: + self._key_cache = _cmpkey( + self._epoch, + self._release, + self._pre, + self._post, + self._dev, + self._local, + ) + if other._key_cache is None: + other._key_cache = _cmpkey( + other._epoch, + other._release, + other._pre, + other._post, + other._dev, + other._local, + ) + return self._key_cache != other._key_cache + + if not isinstance(other, _BaseVersion): + return NotImplemented + + return super().__ne__(other) + @property @_deprecated("Version._version is private and will be removed soon") def _version(self) -> _Version: diff --git a/tests/test_version.py b/tests/test_version.py index 68d5dbe4..95d0aa00 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -12,7 +12,13 @@ import pretend import pytest -from packaging.version import InvalidVersion, Version, _VersionReplace, parse +from packaging.version import ( + InvalidVersion, + Version, + _BaseVersion, + _VersionReplace, + parse, +) if typing.TYPE_CHECKING: from collections.abc import Callable @@ -103,6 +109,27 @@ def test_parse_raises() -> None: ] +# Simple _BaseVersion subclass for testing comparison with non-Version types +class SimpleVersion(_BaseVersion): + """A simple _BaseVersion subclass for testing cross-type comparisons.""" + + def __init__(self, key: typing.Any) -> None: # noqa: ANN401 + # If key is a string, parse it as a version to create a compatible key + if isinstance(key, str): + parsed = Version(key) + self._key_tuple = parsed._key + else: + self._key_tuple = key + + @property + def _key(self) -> typing.Any: # noqa: ANN401 + return self._key_tuple + + @_key.setter + def _key(self, value: typing.Any) -> None: # noqa: ANN401 + self._key_tuple = value + + class TestVersion: @pytest.mark.parametrize("version", VERSIONS) def test_valid_versions(self, version: str) -> None: @@ -800,6 +827,76 @@ def test_compare_other(self, op: str, expected: bool) -> None: assert getattr(operator, op)(Version("1"), other) is expected + @pytest.mark.parametrize( + "op", ["__lt__", "__le__", "__eq__", "__ge__", "__gt__", "__ne__"] + ) + def test_base_version_notimplemented_with_non_base_version(self, op: str) -> None: + """Test _BaseVersion returns NotImplemented with non-_BaseVersion.""" + v = SimpleVersion("1.0") + assert getattr(v, op)(1) is NotImplemented + + def test_base_version_hash(self) -> None: + """Test that _BaseVersion hash works""" + v = SimpleVersion("1.0") + assert isinstance(hash(v), int) + + def test_base_version_ne_with_base_version(self) -> None: + """Test _BaseVersion.__ne__ with another _BaseVersion.""" + v1 = SimpleVersion("1.0") + v2 = SimpleVersion("2.0") + assert v1 != v2 + + def test_version_compare_with_base_version_subclass(self) -> None: + """Test Version comparison with another _BaseVersion subclass""" + v1 = Version("1.0") + v2 = SimpleVersion("1.0") + + # All comparisons should work with compatible keys + assert v1 == v2 + assert v1 <= v2 + assert v1 >= v2 + assert v1 == v2 + assert not (v1 < v2) + assert not (v1 > v2) + + # Test with different versions to exercise != path + v3 = Version("1.0") + v4 = SimpleVersion("2.0") + assert v3 != v4 + + def test_version_ne_with_uncached_keys(self) -> None: + """Test Version.__ne__ populates cache when comparing with another Version""" + v1 = Version("1.0") + v2 = Version("2.0") + + # Test with both caches None + result = v1 != v2 + assert result is True + + # Test with v1 cached, v3 uncached + v3 = Version("1.5") + result = v1 != v3 + assert result is True + + # Test with v3 cached, v4 uncached (the reverse case) + v4 = Version("1.2") + result = v4 != v3 + assert result is True + + def test_version_le_with_uncached_keys(self) -> None: + """Test Version.__le__ populates cache when comparing with another Version""" + v1 = Version("1.0") + v2 = Version("2.0") + + # Test <= with both caches None + result = v1 <= v2 + assert result is True + + # Test with v1 cached (from above), v3 uncached + v3 = Version("1.5") + result = v1 <= v3 + assert result is True + def test_major_version(self) -> None: assert Version("2.1.0").major == 2