From 9d4a1007cbe5e1ed1392c57ab96c5a6466708246 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Wed, 15 Jan 2025 18:13:07 +0800 Subject: [PATCH] Ranging Service --- .vscode/settings.json | 1 + bumble/gatt.py | 10 + bumble/profiles/rap.py | 692 +++++++++++++++++++++++++++++++ examples/run_channel_sounding.py | 29 +- tests/rap_test.py | 376 +++++++++++++++++ 5 files changed, 1107 insertions(+), 1 deletion(-) create mode 100644 bumble/profiles/rap.py create mode 100644 tests/rap_test.py diff --git a/.vscode/settings.json b/.vscode/settings.json index e0ff04e1..65233bc1 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -80,6 +80,7 @@ "subband", "subbands", "subevent", + "subevents", "Subrating", "substates", "tobytes", diff --git a/bumble/gatt.py b/bumble/gatt.py index 96e836cf..6cd332ce 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -123,6 +123,8 @@ GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label') GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio') GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation') +GATT_INDUSTRIAL_MEASUREMENT_DEVICE_SERVICE = UUID.from_16_bits(0x185A, 'Industrial Measurement Device Service') +GATT_RANGING_SERVICE = UUID.from_16_bits(0x185B, 'Ranging Service') # Attribute Types GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service') @@ -286,6 +288,14 @@ GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume') GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT') +# Ranging Service +GATT_RAS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C14, "RAS Features") +GATT_REAL_TIME_RANGING_DATA_CHARACTERISTIC = UUID.from_16_bits(0x2C15, "Real-time Ranging Data") +GATT_ON_DEMAND_RANGING_DATA_CHARACTERISTIC = UUID.from_16_bits(0x2C16, "On-demand Ranging Data") +GATT_RAS_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2C17, "RAS Control Point") +GATT_RANGING_DATA_READY_CHARACTERISTIC = UUID.from_16_bits(0x2C18, "Ranging Data Ready") +GATT_RANGING_DATA_OVERWRITTEN_CHARACTERISTIC = UUID.from_16_bits(0x2C19, "Ranging Data Overwritten") + # Apple Notification Center Service GATT_ANCS_SERVICE = UUID('7905F431-B5CE-4E99-A40F-4B1E122D00D0', 'Apple Notification Center') GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC = UUID('9FBF120D-6301-42D9-8C58-25E699A21DBD', 'Notification Source') diff --git a/bumble/profiles/rap.py b/bumble/profiles/rap.py new file mode 100644 index 00000000..12bb4b10 --- /dev/null +++ b/bumble/profiles/rap.py @@ -0,0 +1,692 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import asyncio +import dataclasses +import enum +import functools +import logging +import struct + +from typing_extensions import Self + +from bumble import att +from bumble import core +from bumble import device +from bumble import gatt +from bumble import gatt_client +from bumble import hci +from bumble import utils + +logger = logging.getLogger(__name__) + + +class RasFeatures(enum.IntFlag): + '''Ranging Service - 3.1.1 RAS Features format.''' + + REAL_TIME_RANGING_DATA = 0x01 + RETRIEVE_LOST_RANGING_DATA_SEGMENTS = 0x02 + ABORT_OPERATION = 0x04 + FILTER_RANGING_DATA = 0x08 + + +class RasControlPointOpCode(utils.OpenIntEnum): + '''Ranging Service - 3.3.1 RAS Control Point Op Codes and Parameters requirements.''' + + GET_RANGING_DATA = 0x00 + ACK_RANGING_DATA = 0x01 + RETRIEVE_LOST_RANGING_DATA_SEGMENTS = 0x02 + ABORT_OPERATION = 0x03 + SET_FILTER = 0x04 + + +class RasControlPointResponseOpCode(utils.OpenIntEnum): + '''Ranging Service - 3.3.1 RAS Control Point Op Codes and Parameters requirements.''' + + COMPLETE_RANGING_DATA_RESPONSE = 0x00 + COMPLETE_LOST_RANGING_DATA_RESPONSE = 0x01 + RESPONSE_CODE = 0x02 + + +class RasControlPointResponseCode(utils.OpenIntEnum): + '''Ranging Service - 3.3.1 RAS Control Point Op Codes and Parameters requirements.''' + + # RFU = 0x00 + + # Normal response for a successful operation + SUCCESS = 0x01 + # Normal response if an unsupported Op Code is received + OP_CODE_NOT_SUPPORTED = 0x02 + # Normal response if Parameter received does not meet the requirements of the + # service + INVALID_PARAMETER = 0x03 + # Normal response for a successful write operation where the values written to the + # RAS Control Point are being persisted. + SUCCESS_PERSISTED = 0x04 + # Normal response if a request for Abort is unsuccessful + ABORT_UNSUCCESSFUL = 0x05 + # Normal response if unable to complete a procedure for any reason + PROCEDURE_NOT_COMPLETED = 0x06 + # Normal response if the Server is still busy with other requests + SERVER_BUSY = 0x07 + # Normal response if the requested Ranging Counter is not found + NO_RECORDS_FOUND = 0x08 + + +@dataclasses.dataclass +class SegmentationHeader: + '''Ranging Service - 3.2.1.1 Segmentation Header.''' + + is_first: bool + is_last: bool + segment_index: int + + def __bytes__(self) -> bytes: + return bytes( + [ + ( + ((self.segment_index & 0x3F) << 2) + | (0x01 if self.is_first else 0x00) + | (0x02 if self.is_last else 0x00) + ) + ] + ) + + @classmethod + def from_bytes(cls: type[Self], data: bytes) -> Self: + return cls( + is_first=bool(data[0] & 0x01), + is_last=bool(data[0] & 0x02), + segment_index=data[0] >> 2, + ) + + +@dataclasses.dataclass +class RangingHeader: + '''Ranging Service - Table 3.7: Ranging Header structure.''' + + configuration_id: int + selected_tx_power: int + antenna_paths_mask: int + ranging_counter: int + + def __bytes__(self) -> bytes: + return struct.pack( + ' Self: + ranging_counter_and_configuration_id, selected_tx_power, antenna_paths_mask = ( + struct.unpack_from('> 12, + selected_tx_power=selected_tx_power, + antenna_paths_mask=antenna_paths_mask, + ) + + +@dataclasses.dataclass +class Step: + '''Ranging Service - Table 3.8: Subevent Header and Data structure.''' + + mode: int + data: bytes + + def __bytes__(self) -> bytes: + return bytes([self.mode]) + self.data + + @classmethod + def parse_from( + cls: type[Self], + data: bytes, + config: device.ChannelSoundingConfig, + num_antenna_paths: int, + offset: int = 0, + ) -> tuple[int, Self]: + mode = data[offset] + contain_sounding_sequence = config.rtt_type in ( + hci.RttType.SOUNDING_SEQUENCE_32_BIT, + hci.RttType.SOUNDING_SEQUENCE_96_BIT, + ) + is_initiator = config.role == hci.CsRole.INITIATOR + + # TODO: Parse mode/role-specific data. + if mode == 0: + length = 5 if is_initiator else 3 + elif mode == 1: + length = 12 if contain_sounding_sequence else 6 + elif mode == 2: + length = (num_antenna_paths + 1) * 4 + 1 + elif mode == 3: + length = (num_antenna_paths + 1) * 4 + ( + 13 if contain_sounding_sequence else 7 + ) + else: + raise core.InvalidPacketError(f"Unknown mode 0x{mode:02X}") + return (offset + length + 1), cls( + mode=mode, data=data[offset + 1 : offset + 1 + length] + ) + + +@dataclasses.dataclass +class Subevent: + '''Ranging Service - Table 3.8: Subevent Header and Data structure.''' + + start_acl_connection_event: int + frequency_compensation: int + ranging_done_status: int + subevent_done_status: int + ranging_abort_reason: int + subevent_abort_reason: int + reference_power_level: int + steps: list[Step] = dataclasses.field(default_factory=list) + + def __bytes__(self) -> bytes: + return struct.pack( + ' tuple[int, Self]: + ( + start_acl_connection_event, + frequency_compensation, + ranging_done_status_and_subevent_done_status, + ranging_abort_reason_and_subevent_abort_reason, + reference_power_level, + num_reported_steps, + ) = struct.unpack_from('> 4, + ranging_abort_reason=ranging_abort_reason_and_subevent_abort_reason & 0x0F, + subevent_abort_reason=ranging_abort_reason_and_subevent_abort_reason >> 4, + reference_power_level=reference_power_level, + steps=steps, + ) + + +@dataclasses.dataclass +class RangingData: + '''Ranging Service - 3.2.1 Ranging Data format.''' + + ranging_header: RangingHeader + subevents: list[Subevent] = dataclasses.field(default_factory=list) + + def __bytes__(self) -> bytes: + return bytes(self.ranging_header) + b''.join(map(bytes, self.subevents)) + + @classmethod + def from_bytes( + cls: type[Self], + data: bytes, + config: device.ChannelSoundingConfig, + ) -> Self: + pass + ranging_header = RangingHeader.from_bytes(data) + num_antenna_paths = 0 + antenna_path_mask = ranging_header.antenna_paths_mask + while antenna_path_mask > 0: + if antenna_path_mask & 0x01: + num_antenna_paths += 1 + antenna_path_mask >>= 1 + + subevents: list[Subevent] = [] + offset = 4 + while offset < len(data): + offset, subevent = Subevent.parse_from( + data=data, + config=config, + num_antenna_paths=num_antenna_paths, + offset=offset, + ) + subevents.append(subevent) + return cls(ranging_header=ranging_header, subevents=subevents) + + +class RangingService(gatt.TemplateService): + UUID = gatt.GATT_RANGING_SERVICE + + @dataclasses.dataclass + class Client: + active_mode: RangingService.Mode + cccd_value: gatt.ClientCharacteristicConfigurationBits + # procedure counter to ranging data + ranging_data_table: dict[int, RangingData] = dataclasses.field( + default_factory=dict + ) + # config id to procedure counter + active_procedure_counter: dict[int, int] = dataclasses.field( + default_factory=dict + ) + + class Mode(enum.IntEnum): + '''Bumble-defined mode enum.''' + + INACTIVE = 0 + ON_DEMAND = 1 + REAL_TIME = 2 + + clients: dict[device.Connection, Client] + real_time_ranging_data_characteristic: gatt.Characteristic[bytes] | None + + def __init__( + self, + device: device.Device, + ras_features: RasFeatures, + ) -> None: + self.clients = {} + self.device = device + self.device.host.on('cs_subevent_result', self._on_subevent_result) + self.device.host.on('cs_subevent_result_continue', self._post_subevent_result) + self.device.on('connection', self._on_connection) + self.ras_features_characteristic = gatt.Characteristic[bytes]( + gatt.GATT_RAS_FEATURES_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.READ, + permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION, + value=struct.pack(" None: + cccd_value = gatt.ClientCharacteristicConfigurationBits( + int.from_bytes(data, 'little') + ) + logger.debug("on_cccd_write, connection=%s, value=%s", connection, cccd_value) + if not (client := self.clients.get(connection)): + client = self.clients[connection] = self.Client( + active_mode=self.Mode.INACTIVE, + cccd_value=gatt.ClientCharacteristicConfigurationBits.DEFAULT, + ) + + if client.active_mode not in (self.Mode.INACTIVE, mode): + logger.error("Forbid subscribing when another mode is active!") + raise att.ATT_Error(att.ErrorCode.WRITE_REQUEST_REJECTED) + + if cccd_value == gatt.ClientCharacteristicConfigurationBits.DEFAULT: + client.active_mode = self.Mode.INACTIVE + else: + client.active_mode = mode + client.cccd_value = cccd_value + + def _on_cccd_read( + self, + mode: Mode, + connection: device.Connection, + ) -> bytes: + if not (client := self.clients.get(connection)): + client = self.clients[connection] = self.Client( + active_mode=self.Mode.INACTIVE, + cccd_value=gatt.ClientCharacteristicConfigurationBits.DEFAULT, + ) + if mode != client.active_mode: + client.cccd_value = gatt.ClientCharacteristicConfigurationBits.DEFAULT + return struct.pack(" None: + if not (connection := self.device.lookup_connection(event.connection_handle)): + logger.error( + "Subevent for unknown connection 0x%04X", event.connection_handle + ) + return + if not (client := self.clients[connection]): + return + procedure_counter = event.procedure_counter + if not (ranging_data := client.ranging_data_table.get(procedure_counter)): + ranging_data = client.ranging_data_table[procedure_counter] = RangingData( + ranging_header=RangingHeader( + event.config_id, + selected_tx_power=connection.cs_procedures[ + event.config_id + ].selected_tx_power, + antenna_paths_mask=(1 << (event.num_antenna_paths + 1)) - 1, + ranging_counter=procedure_counter, + ) + ) + + subevent = Subevent( + start_acl_connection_event=event.start_acl_conn_event_counter, + frequency_compensation=event.frequency_compensation, + ranging_abort_reason=event.procedure_done_status, + ranging_done_status=event.procedure_done_status, + subevent_done_status=event.subevent_done_status, + subevent_abort_reason=event.abort_reason, + reference_power_level=event.reference_power_level, + ) + ranging_data.subevents.append(subevent) + client.active_procedure_counter[event.config_id] = procedure_counter + self.ranging_data_ready_characteristic.value = struct.pack( + ' None: + if not (connection := self.device.lookup_connection(event.connection_handle)): + logger.error( + "Subevent for unknown connection 0x%04X", event.connection_handle + ) + return + if not (client := self.clients[connection]): + return + procedure_counter = client.active_procedure_counter[event.config_id] + ranging_data = client.ranging_data_table[procedure_counter] + subevent = ranging_data.subevents[-1] + subevent.ranging_done_status = event.procedure_done_status + subevent.subevent_done_status = event.subevent_done_status + subevent.steps.extend( + [Step(mode, data) for mode, data in zip(event.step_mode, event.step_data)] + ) + + if event.procedure_done_status == hci.CsDoneStatus.ALL_RESULTS_COMPLETED: + self.device.abort_on( + 'flush', + self.device.notify_subscribers(self.ranging_data_ready_characteristic), + ) + if client.active_mode == self.Mode.REAL_TIME: + connection.abort_on( + 'disconnection', + self.send_ranging_data( + connection=connection, + data=bytes(ranging_data), + ), + ) + + async def _on_write_control_point( + self, connection: device.Connection, data: bytes + ) -> None: + op_code = data[0] + response: bytes + if op_code == RasControlPointOpCode.GET_RANGING_DATA: + ranging_counter = struct.unpack_from(' None: + connection.once( + 'disconnection', + functools.partial(self._on_disconnection, connection), + ) + + def _on_disconnection(self, connection: device.Connection, reason: int) -> None: + del reason + self.clients.pop(connection, None) + + async def send_ranging_data( + self, + connection: device.Connection, + data: bytes, + ) -> None: + mps = connection.att_mtu - 6 + client = self.clients[connection] + if client.active_mode == self.Mode.ON_DEMAND: + characteristic = self.on_demand_ranging_data_characteristic + elif client.active_mode == self.Mode.REAL_TIME: + if not self.real_time_ranging_data_characteristic: + logger.error( + "Trying to send real time ranging data, but it's not supported." + ) + return + characteristic = self.real_time_ranging_data_characteristic + else: + logger.debug('%s does not enable ranging data.', client) + return + + if client.cccd_value & gatt.ClientCharacteristicConfigurationBits.NOTIFICATION: + method = self.device.notify_subscriber + elif client.cccd_value & gatt.ClientCharacteristicConfigurationBits.INDICATION: + method = self.device.indicate_subscriber + else: + logger.debug('%s does not enable ranging data.', client) + return + + for index, offset in enumerate(range(0, len(data), mps)): + fragment = data[offset : offset + mps] + header = SegmentationHeader( + is_first=(offset == 0), + is_last=(offset + len(fragment) >= len(data)), + segment_index=index, + ) + await method( + connection=connection, + attribute=characteristic, + value=bytes(header) + fragment, + force=True, + ) + + +# ----------------------------------------------------------------------------- +class RangingServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = RangingService + + ras_features_characteristic: gatt_client.CharacteristicProxy[bytes] + on_demand_ranging_data_characteristic: gatt_client.CharacteristicProxy[bytes] + ras_control_point_characteristic: gatt_client.CharacteristicProxy[bytes] + ranging_data_ready_characteristic: gatt_client.CharacteristicProxy[bytes] + ranging_data_overwritten_characteristic: gatt_client.CharacteristicProxy[bytes] + + # Optional. + real_time_ranging_data_characteristic: ( + gatt_client.CharacteristicProxy[bytes] | None + ) = None + + def __init__(self, service_proxy: gatt_client.ServiceProxy): + self.service_proxy = service_proxy + + for attribute, uuid in { + "ras_features_characteristic": gatt.GATT_RAS_FEATURES_CHARACTERISTIC, + "on_demand_ranging_data_characteristic": gatt.GATT_ON_DEMAND_RANGING_DATA_CHARACTERISTIC, + "ras_control_point_characteristic": gatt.GATT_RAS_CONTROL_POINT_CHARACTERISTIC, + "ranging_data_ready_characteristic": gatt.GATT_RANGING_DATA_READY_CHARACTERISTIC, + "ranging_data_overwritten_characteristic": gatt.GATT_RANGING_DATA_OVERWRITTEN_CHARACTERISTIC, + }.items(): + if not (characteristics := service_proxy.get_characteristics_by_uuid(uuid)): + raise gatt.InvalidServiceError( + f"Missing mandatory characteristic {uuid}" + ) + setattr(self, attribute, characteristics[0]) + + if characteristics := service_proxy.get_characteristics_by_uuid( + gatt.GATT_REAL_TIME_RANGING_DATA_CHARACTERISTIC + ): + self.real_time_ranging_data_characteristic = characteristics[0] diff --git a/examples/run_channel_sounding.py b/examples/run_channel_sounding.py index 346b7755..54409fff 100644 --- a/examples/run_channel_sounding.py +++ b/examples/run_channel_sounding.py @@ -27,6 +27,7 @@ from bumble import hci from bumble.device import Connection, Device, ChannelSoundingCapabilities from bumble.transport import open_transport_or_link +from bumble.profiles import rap # From https://cs.android.com/android/platform/superproject/main/+/main:packages/modules/Bluetooth/system/gd/hci/distance_measurement_manager.cc. CS_TONE_ANTENNA_CONFIG_MAPPING_TABLE = [ @@ -86,11 +87,37 @@ async def main() -> None: ) await device.power_on() assert (local_cs_capabilities := device.cs_capabilities) + ras = rap.RangingService( + device=device, ras_features=rap.RasFeatures.REAL_TIME_RANGING_DATA + ) + device.add_service(ras) if len(sys.argv) == 3: + advertising_data = bytes( + core.AdvertisingData( + [ + ( + core.AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(device.name, 'utf-8'), + ), + ( + core.AdvertisingData.FLAGS, + bytes( + [core.AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG] + ), + ), + ( + core.AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(ras.uuid), + ), + ] + ) + ) print('<<< Start Advertising') await device.start_advertising( - own_address_type=hci.OwnAddressType.RANDOM, auto_restart=True + own_address_type=hci.OwnAddressType.RANDOM, + auto_restart=True, + advertising_data=advertising_data, ) def on_cs_capabilities( diff --git a/tests/rap_test.py b/tests/rap_test.py new file mode 100644 index 00000000..56df380d --- /dev/null +++ b/tests/rap_test.py @@ -0,0 +1,376 @@ +# Copyright 2021-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import asyncio +import pytest + +from . import test_utils + +from bumble import att +from bumble import device +from bumble import gatt +from bumble import hci +from bumble.profiles import rap + + +# ----------------------------------------------------------------------------- +def make_config(role: hci.CsRole, rtt_type: hci.RttType = hci.RttType.AA_ONLY): + return device.ChannelSoundingConfig( + config_id=0, + main_mode_type=0, + sub_mode_type=0, + min_main_mode_steps=0, + max_main_mode_steps=0, + main_mode_repetition=0, + mode_0_steps=0, + role=role, + rtt_type=rtt_type, + cs_sync_phy=0, + channel_map=b'', + channel_map_repetition=0, + channel_selection_type=0, + ch3c_shape=0, + ch3c_jump=0, + reserved=0, + t_ip1_time=0, + t_ip2_time=0, + t_fcs_time=0, + t_pm_time=0, + ) + + +# ----------------------------------------------------------------------------- +async def make_connections( + ras_features: rap.RasFeatures, +) -> tuple[rap.RangingService, rap.RangingServiceProxy]: + devices = await test_utils.TwoDevices.create_with_connection() + assert (server_connection := devices.connections[0]) + assert (client_connection := devices.connections[1]) + # Mock encryption. + server_connection.encryption = 1 + client_connection.encryption = 1 + server = rap.RangingService(devices[0], ras_features) + devices[0].add_service(server) + + peer = device.Peer(client_connection) + client = await peer.discover_service_and_create_proxy(rap.RangingServiceProxy) + assert client + return server, client + + +# ----------------------------------------------------------------------------- +def test_parse_ranging_data_initiator_without_sounding_sequence() -> None: + config = make_config(role=hci.CsRole.INITIATOR) + expected_ranging_data = rap.RangingData( + ranging_header=rap.RangingHeader( + configuration_id=0, + selected_tx_power=-1, + antenna_paths_mask=0x0F, + ranging_counter=2, + ), + subevents=[ + rap.Subevent( + start_acl_connection_event=0, + frequency_compensation=1, + ranging_done_status=2, + ranging_abort_reason=3, + subevent_abort_reason=4, + subevent_done_status=5, + reference_power_level=-2, + steps=[ + rap.Step(mode=0, data=bytes(5)), + rap.Step(mode=1, data=bytes(6)), + rap.Step(mode=2, data=bytes(21)), + rap.Step(mode=3, data=bytes(27)), + ], + ) + ], + ) + + assert ( + rap.RangingData.from_bytes(bytes(expected_ranging_data), config) + == expected_ranging_data + ) + + +# ----------------------------------------------------------------------------- +def test_parse_ranging_data_reflector_without_sounding_sequence() -> None: + config = make_config(role=hci.CsRole.REFLECTOR) + expected_ranging_data = rap.RangingData( + ranging_header=rap.RangingHeader( + configuration_id=0, + selected_tx_power=-1, + antenna_paths_mask=0x0F, + ranging_counter=2, + ), + subevents=[ + rap.Subevent( + start_acl_connection_event=0, + frequency_compensation=1, + ranging_done_status=2, + ranging_abort_reason=3, + subevent_abort_reason=4, + subevent_done_status=5, + reference_power_level=-2, + steps=[ + rap.Step(mode=0, data=bytes(3)), + rap.Step(mode=1, data=bytes(6)), + rap.Step(mode=2, data=bytes(21)), + rap.Step(mode=3, data=bytes(27)), + ], + ) + ], + ) + + assert ( + rap.RangingData.from_bytes(bytes(expected_ranging_data), config) + == expected_ranging_data + ) + + +# ----------------------------------------------------------------------------- +def test_parse_ranging_data_initiator_with_sounding_sequence() -> None: + config = make_config( + role=hci.CsRole.INITIATOR, rtt_type=hci.RttType.SOUNDING_SEQUENCE_32_BIT + ) + expected_ranging_data = rap.RangingData( + ranging_header=rap.RangingHeader( + configuration_id=0, + selected_tx_power=-1, + antenna_paths_mask=0x0F, + ranging_counter=2, + ), + subevents=[ + rap.Subevent( + start_acl_connection_event=0, + frequency_compensation=1, + ranging_done_status=2, + ranging_abort_reason=3, + subevent_abort_reason=4, + subevent_done_status=5, + reference_power_level=-2, + steps=[ + rap.Step(mode=0, data=bytes(5)), + rap.Step(mode=1, data=bytes(12)), + rap.Step(mode=2, data=bytes(21)), + rap.Step(mode=3, data=bytes(33)), + ], + ) + ], + ) + + assert ( + rap.RangingData.from_bytes(bytes(expected_ranging_data), config) + == expected_ranging_data + ) + + +# ----------------------------------------------------------------------------- +def test_parse_ranging_data_reflector_with_sounding_sequence() -> None: + config = make_config( + role=hci.CsRole.REFLECTOR, + rtt_type=hci.RttType.SOUNDING_SEQUENCE_96_BIT, + ) + expected_ranging_data = rap.RangingData( + ranging_header=rap.RangingHeader( + configuration_id=0, + selected_tx_power=-1, + antenna_paths_mask=0x0F, + ranging_counter=2, + ), + subevents=[ + rap.Subevent( + start_acl_connection_event=0, + frequency_compensation=1, + ranging_done_status=2, + ranging_abort_reason=3, + subevent_abort_reason=4, + subevent_done_status=5, + reference_power_level=-2, + steps=[ + rap.Step(mode=0, data=bytes(3)), + rap.Step(mode=1, data=bytes(12)), + rap.Step(mode=2, data=bytes(21)), + rap.Step(mode=3, data=bytes(33)), + ], + ) + ] + * 2, + ) + + assert ( + rap.RangingData.from_bytes(bytes(expected_ranging_data), config) + == expected_ranging_data + ) + + +# ----------------------------------------------------------------------------- +async def test_subscribe_on_demand_cccd() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + server_connection = next(iter(server.device.connections.values())) + + assert client.real_time_ranging_data_characteristic + await client.on_demand_ranging_data_characteristic.subscribe(prefer_notify=True) + assert ( + server.clients[server_connection].active_mode + == rap.RangingService.Mode.ON_DEMAND + ) + assert ( + server.clients[server_connection].cccd_value + == gatt.ClientCharacteristicConfigurationBits.NOTIFICATION + ) + assert ( + cccd := client.on_demand_ranging_data_characteristic.get_descriptor( + gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR + ) + ) + assert ( + int.from_bytes(await cccd.read_value(), 'little') + == gatt.ClientCharacteristicConfigurationBits.NOTIFICATION + ) + + await client.on_demand_ranging_data_characteristic.unsubscribe() + assert ( + server.clients[server_connection].active_mode + == rap.RangingService.Mode.INACTIVE + ) + + +# ----------------------------------------------------------------------------- +async def test_subscribe_real_time_cccd() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + server_connection = next(iter(server.device.connections.values())) + + assert client.real_time_ranging_data_characteristic + await client.real_time_ranging_data_characteristic.subscribe(prefer_notify=True) + assert ( + server.clients[server_connection].active_mode + == rap.RangingService.Mode.REAL_TIME + ) + assert ( + server.clients[server_connection].cccd_value + == gatt.ClientCharacteristicConfigurationBits.NOTIFICATION + ) + assert ( + cccd := client.real_time_ranging_data_characteristic.get_descriptor( + gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR + ) + ) + assert ( + int.from_bytes(await cccd.read_value(), 'little') + == gatt.ClientCharacteristicConfigurationBits.NOTIFICATION + ) + + await client.real_time_ranging_data_characteristic.unsubscribe() + assert ( + server.clients[server_connection].active_mode + == rap.RangingService.Mode.INACTIVE + ) + + +# ----------------------------------------------------------------------------- +async def test_read_cccd_without_on_inactive() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + + assert client.real_time_ranging_data_characteristic + await client.real_time_ranging_data_characteristic.discover_descriptors() + assert ( + cccd := client.real_time_ranging_data_characteristic.get_descriptor( + gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR + ) + ) + assert (await cccd.read_value()) == bytes(2) + + +# ----------------------------------------------------------------------------- +async def test_subscribe_real_time_when_on_demand_is_on() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + assert client.real_time_ranging_data_characteristic + + await client.on_demand_ranging_data_characteristic.subscribe(prefer_notify=True) + with pytest.raises(att.ATT_Error): + await client.real_time_ranging_data_characteristic.subscribe(prefer_notify=True) + + +# ----------------------------------------------------------------------------- +async def test_subscribe_on_demand_when_real_time_is_on() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + assert client.real_time_ranging_data_characteristic + + await client.real_time_ranging_data_characteristic.subscribe(prefer_notify=True) + with pytest.raises(att.ATT_Error): + await client.on_demand_ranging_data_characteristic.subscribe(prefer_notify=True) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize('prefer_notify,', (True, False)) +async def test_send_ranging_data_on_demand(prefer_notify: bool) -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + server_connection = next(iter(server.device.connections.values())) + + notifications = asyncio.Queue[bytes]() + await client.on_demand_ranging_data_characteristic.subscribe( + notifications.put_nowait, prefer_notify=prefer_notify + ) + expected_data = bytes([i % 256 for i in range(4096)]) + await server.send_ranging_data(server_connection, expected_data) + + actual_data = b'' + while True: + notification = await asyncio.wait_for(notifications.get(), 0.1) + segmentation_header = rap.SegmentationHeader.from_bytes(notification) + actual_data += notification[1:] + if segmentation_header.is_last: + break + assert actual_data == expected_data + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize('prefer_notify,', (True, False)) +async def test_send_ranging_data_real_time(prefer_notify: bool) -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + assert client.real_time_ranging_data_characteristic + server_connection = next(iter(server.device.connections.values())) + + notifications = asyncio.Queue[bytes]() + await client.real_time_ranging_data_characteristic.subscribe( + notifications.put_nowait, prefer_notify=prefer_notify + ) + expected_data = bytes([i % 256 for i in range(4096)]) + await server.send_ranging_data(server_connection, expected_data) + + actual_data = b'' + while True: + notification = await asyncio.wait_for(notifications.get(), 0.1) + segmentation_header = rap.SegmentationHeader.from_bytes(notification) + actual_data += notification[1:] + if segmentation_header.is_last: + break + assert actual_data == expected_data + + +# ----------------------------------------------------------------------------- +async def test_send_ranging_data_inactive() -> None: + server, client = await make_connections(rap.RasFeatures.REAL_TIME_RANGING_DATA) + server_connection = next(iter(server.device.connections.values())) + await client.on_demand_ranging_data_characteristic.subscribe() + await client.on_demand_ranging_data_characteristic.unsubscribe() + + expected_data = bytes([i % 256 for i in range(4096)]) + await server.send_ranging_data(server_connection, expected_data)