diff --git a/src/etcetra/client.py b/src/etcetra/client.py index e63bc40..c975790 100644 --- a/src/etcetra/client.py +++ b/src/etcetra/client.py @@ -362,10 +362,7 @@ def get_prefix( encoding='utf-8', ): encoded_key = key.encode(encoding) - if key[-1] == '/' and len(key) >= 2: - range_end = encoded_key[:-2] + bytes([encoded_key[-2] + 1]) + b'/' - else: - range_end = encoded_key[:-1] + bytes([encoded_key[-1] + 1]) + range_end = increment_last_byte(encoded_key) return rpc_pb2.RangeRequest( key=encoded_key, range_end=range_end, @@ -408,10 +405,7 @@ def delete_prefix( ): # TODO: Implement prev_kv response encoded_key = key.encode(encoding) - if key[-1] == '/' and len(key) >= 2: - range_end = encoded_key[:-2] + bytes([encoded_key[-2] + 1]) + b'/' - else: - range_end = encoded_key[:-1] + bytes([encoded_key[-1] + 1]) + range_end = increment_last_byte(encoded_key) return rpc_pb2.DeleteRangeRequest( key=encoded_key, range_end=range_end, @@ -461,10 +455,7 @@ def keys_prefix( encoding='utf-8', ): encoded_key = key.encode(encoding) - if key[-1] == '/' and len(key) >= 2: - range_end = encoded_key[:-2] + bytes([encoded_key[-2] + 1]) + b'/' - else: - range_end = encoded_key[:-1] + bytes([encoded_key[-1] + 1]) + range_end = increment_last_byte(encoded_key) return rpc_pb2.RangeRequest( key=encoded_key, range_end=range_end, @@ -1309,10 +1300,7 @@ def watch_prefix( encoding = self.encoding encoded_key = key.encode(encoding) - if key[-1] == '/' and len(key) >= 2: - range_end = encoded_key[:-2] + bytes([encoded_key[-2] + 1]) + b'/' - else: - range_end = encoded_key[:-1] + bytes([encoded_key[-1] + 1]) + range_end = increment_last_byte(encoded_key) return self._watch_impl( key.encode(encoding), encoding, ready_event=ready_event, filters=filters, prev_kv=prev_kv, @@ -1671,3 +1659,13 @@ async def __aexit__(self, exc_type, exc, tb) -> Optional[bool]: self._lock_id = None self._lease_id = None return False + + +def increment_last_byte(encoded_key): + s = bytearray(encoded_key) + for i in range(len(s) - 1, -1, -1): + if s[i] < 0xff: + s[i] += 1 + return bytes(s[:i + 1]) + else: + return b'\x00'