From df697c6513ed443dbb4497338cb99fc837c5a7c4 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 18 Dec 2025 15:23:30 +0800 Subject: [PATCH] Add EATT Support --- apps/lea_unicast/app.py | 1 + bumble/att.py | 77 ++++++++- bumble/device.py | 39 +++-- bumble/gatt.py | 4 +- bumble/gatt_client.py | 112 ++++++++---- bumble/gatt_server.py | 327 ++++++++++++++++++++++-------------- bumble/l2cap.py | 6 + examples/run_gatt_client.py | 21 ++- tests/gatt_test.py | 185 +++++++++++++++++++- 9 files changed, 578 insertions(+), 194 deletions(-) diff --git a/apps/lea_unicast/app.py b/apps/lea_unicast/app.py index c96bb10c..4e31c0eb 100644 --- a/apps/lea_unicast/app.py +++ b/apps/lea_unicast/app.py @@ -298,6 +298,7 @@ async def run(self) -> None: advertising_interval_max=25, address=Address('F1:F2:F3:F4:F5:F6'), identity_address_type=Address.RANDOM_DEVICE_ADDRESS, + eatt_enabled=True, ) device_config.le_enabled = True diff --git a/bumble/att.py b/bumble/att.py index 89adaf9d..6e7c989f 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -34,10 +34,13 @@ TYPE_CHECKING, ClassVar, Generic, + TypeAlias, TypeVar, ) -from bumble import hci, utils +from typing_extensions import TypeIs + +from bumble import hci, l2cap, utils from bumble.colors import color from bumble.core import UUID, InvalidOperationError, ProtocolError from bumble.hci import HCI_Object @@ -50,6 +53,14 @@ _T = TypeVar('_T') +Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel" +EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel + + +def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]: + return isinstance(bearer, EnhancedBearer) + + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -58,6 +69,7 @@ ATT_CID = 0x04 ATT_PSM = 0x001F +EATT_PSM = 0x0027 class Opcode(hci.SpecableEnum): ATT_ERROR_RESPONSE = 0x01 @@ -780,6 +792,43 @@ def write(self, connection: Connection, value: _T) -> Awaitable[None] | None: return self._write(connection, value) +# ----------------------------------------------------------------------------- +class AttributeValueV2(Generic[_T]): + ''' + Attribute value compatible with enhanced bearers. + + The only difference between AttributeValue and AttributeValueV2 is that the actual + bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer) + will be passed into read and write callbacks in V2, while in V1 it is always + the base ACL connection. + + This is only required when attributes must distinguish bearers, otherwise normal + `AttributeValue` objects are also applicable in enhanced bearers. + ''' + + def __init__( + self, + read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None, + write: ( + Callable[[Bearer, _T], Awaitable[None]] + | Callable[[Bearer, _T], None] + | None + ) = None, + ): + self._read = read + self._write = write + + def read(self, bearer: Bearer) -> _T | Awaitable[_T]: + if self._read is None: + raise InvalidOperationError('AttributeValue has no read function') + return self._read(bearer) + + def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None: + if self._write is None: + raise InvalidOperationError('AttributeValue has no write function') + return self._write(bearer, value) + + # ----------------------------------------------------------------------------- class Attribute(utils.EventEmitter, Generic[_T]): class Permissions(enum.IntFlag): @@ -855,7 +904,8 @@ def encode_value(self, value: _T) -> bytes: def decode_value(self, value: bytes) -> _T: return value # type: ignore - async def read_value(self, connection: Connection) -> bytes: + async def read_value(self, bearer: Bearer) -> bytes: + connection = bearer.connection if is_enhanced_bearer(bearer) else bearer if ( (self.permissions & self.READ_REQUIRES_ENCRYPTION) and connection is not None @@ -890,6 +940,17 @@ async def read_value(self, connection: Connection) -> bytes: raise ATT_Error( error_code=error.error_code, att_handle=self.handle ) from error + elif isinstance(self.value, AttributeValueV2): + try: + read_value = self.value.read(bearer) + if inspect.isawaitable(read_value): + value = await read_value + else: + value = read_value + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error else: value = self.value @@ -897,7 +958,8 @@ async def read_value(self, connection: Connection) -> bytes: return b'' if value is None else self.encode_value(value) - async def write_value(self, connection: Connection, value: bytes) -> None: + async def write_value(self, bearer: Bearer, value: bytes) -> None: + connection = bearer.connection if is_enhanced_bearer(bearer) else bearer if ( (self.permissions & self.WRITE_REQUIRES_ENCRYPTION) and connection is not None @@ -931,6 +993,15 @@ async def write_value(self, connection: Connection, value: bytes) -> None: raise ATT_Error( error_code=error.error_code, att_handle=self.handle ) from error + elif isinstance(self.value, AttributeValueV2): + try: + result = self.value.write(bearer, decoded_value) + if inspect.isawaitable(result): + await result + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error else: self.value = decoded_value diff --git a/bumble/device.py b/bumble/device.py index 9c4c59ab..568d868c 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -41,6 +41,7 @@ from typing_extensions import Self from bumble import ( + att, core, data_types, gatt, @@ -53,7 +54,6 @@ smp, utils, ) -from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from bumble.colors import color from bumble.core import ( AdvertisingData, @@ -1743,7 +1743,6 @@ class Connection(utils.CompositeEventEmitter): EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure" EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update" EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure" - EVENT_CONNECTION_ATT_MTU_UPDATE = "connection_att_mtu_update" EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change" EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = ( "channel_sounding_capabilities_failure" @@ -1846,7 +1845,7 @@ def __init__( self.encryption_key_size = 0 self.authenticated = False self.sc = False - self.att_mtu = ATT_DEFAULT_MTU + self.att_mtu = att.ATT_DEFAULT_MTU self.data_length = DEVICE_DEFAULT_DATA_LENGTH self.gatt_client = gatt_client.Client(self) # Per-connection client self.gatt_server = ( @@ -1996,6 +1995,15 @@ async def get_remote_le_features(self) -> hci.LeFeatureMask: self.peer_le_features = await self.device.get_remote_le_features(self) return self.peer_le_features + def on_att_mtu_update(self, mtu: int): + logger.debug( + f'*** Connection ATT MTU Update: [0x{self.handle:04X}] ' + f'{self.peer_address} as {self.role_name}, ' + f'{mtu}' + ) + self.att_mtu = mtu + self.emit(self.EVENT_CONNECTION_ATT_MTU_UPDATE) + @property def data_packet_queue(self) -> DataPacketQueue | None: return self.device.host.get_data_packet_queue(self.handle) @@ -2079,6 +2087,7 @@ class DeviceConfiguration: l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION, l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE, ) + eatt_enabled: bool = False def __post_init__(self) -> None: self.gatt_services: list[dict[str, Any]] = [] @@ -2497,7 +2506,10 @@ def __init__( add_gap_service=config.gap_service_enabled, add_gatt_service=config.gatt_service_enabled, ) - self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu) + self.l2cap_channel_manager.register_fixed_channel(att.ATT_CID, self.on_gatt_pdu) + + if self.config.eatt_enabled: + self.gatt_server.register_eatt() # Forward some events utils.setup_event_forwarding( @@ -5140,7 +5152,11 @@ def add_default_services( if add_gap_service: self.gatt_server.add_service(GenericAccessService(self.name)) if add_gatt_service: - self.gatt_service = gatt_service.GenericAttributeProfileService() + self.gatt_service = gatt_service.GenericAttributeProfileService( + gatt.ServerSupportedFeatures.EATT_SUPPORTED + if self.config.eatt_enabled + else None + ) self.gatt_server.add_service(self.gatt_service) async def notify_subscriber( @@ -6240,17 +6256,6 @@ def on_le_subrate_change( ) connection.emit(connection.EVENT_LE_SUBRATE_CHANGE) - @host_event_handler - @with_connection_from_handle - def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int): - logger.debug( - f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] ' - f'{connection.peer_address} as {connection.role_name}, ' - f'{att_mtu}' - ) - connection.att_mtu = att_mtu - connection.emit(connection.EVENT_CONNECTION_ATT_MTU_UPDATE) - @host_event_handler @with_connection_from_handle def on_connection_data_length_change( @@ -6437,7 +6442,7 @@ def on_pairing_failure(self, connection: Connection, reason: int) -> None: @with_connection_from_handle def on_gatt_pdu(self, connection: Connection, pdu: bytes): # Parse the L2CAP payload into an ATT PDU object - att_pdu = ATT_PDU.from_bytes(pdu) + att_pdu = att.ATT_PDU.from_bytes(pdu) # Conveniently, even-numbered op codes are client->server and # odd-numbered ones are server->client diff --git a/bumble/gatt.py b/bumble/gatt.py index 836d6c87..11a83788 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -31,7 +31,7 @@ from collections.abc import Iterable, Sequence from typing import TypeVar -from bumble.att import Attribute, AttributeValue +from bumble.att import Attribute, AttributeValue, AttributeValueV2 from bumble.colors import color from bumble.core import UUID, BaseBumbleError @@ -579,7 +579,7 @@ class Descriptor(Attribute): def __str__(self) -> str: if isinstance(self.value, bytes): value_str = self.value.hex() - elif isinstance(self.value, CharacteristicValue): + elif isinstance(self.value, (AttributeValue, AttributeValueV2)): value_str = '' else: value_str = '<...>' diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 487961ce..87823663 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio +import functools import logging import struct from collections.abc import Callable, Iterable @@ -35,9 +36,10 @@ Any, Generic, TypeVar, + overload, ) -from bumble import att, core, utils +from bumble import att, core, l2cap, utils from bumble.colors import color from bumble.core import UUID, InvalidStateError from bumble.gatt import ( @@ -54,12 +56,12 @@ ) from bumble.hci import HCI_Constant +if TYPE_CHECKING: + from bumble import device as device_module + # ----------------------------------------------------------------------------- # Typing # ----------------------------------------------------------------------------- -if TYPE_CHECKING: - from bumble.device import Connection - _T = TypeVar('_T') # ----------------------------------------------------------------------------- @@ -267,8 +269,8 @@ class Client: pending_response: asyncio.futures.Future[att.ATT_PDU] | None pending_request: att.ATT_PDU | None - def __init__(self, connection: Connection) -> None: - self.connection = connection + def __init__(self, bearer: att.Bearer) -> None: + self.bearer = bearer self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1) self.pending_request = None @@ -278,21 +280,78 @@ def __init__(self, connection: Connection) -> None: self.services = [] self.cached_values = {} - connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection) + if att.is_enhanced_bearer(bearer): + bearer.on(bearer.EVENT_CLOSE, self.on_disconnection) + self._bearer_id = ( + f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]' + ) + # Fill the mtu. + bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU) + self.connection = bearer.connection + else: + bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection) + self._bearer_id = f'[0x{bearer.handle:04X}]' + self.connection = bearer + + @overload + @classmethod + async def connect_eatt( + cls, + connection: device_module.Connection, + spec: l2cap.LeCreditBasedChannelSpec | None = None, + ) -> Client: ... + + @overload + @classmethod + async def connect_eatt( + cls, + connection: device_module.Connection, + spec: l2cap.LeCreditBasedChannelSpec | None = None, + count: int = 1, + ) -> list[Client]: ... + + @classmethod + async def connect_eatt( + cls, + connection: device_module.Connection, + spec: l2cap.LeCreditBasedChannelSpec | None = None, + count: int = 1, + ) -> list[Client] | Client: + channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels( + connection, + spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), + count, + ) + + def on_pdu(client: Client, pdu: bytes): + client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu)) + + clients = [cls(channel) for channel in channels] + for channel, client in zip(channels, clients): + channel.sink = functools.partial(on_pdu, client) + channel.att_mtu = att.ATT_DEFAULT_MTU + return clients[0] if count == 1 else clients + + @property + def mtu(self) -> int: + return self.bearer.att_mtu + + @mtu.setter + def mtu(self, value: int) -> None: + self.bearer.on_att_mtu_update(value) def send_gatt_pdu(self, pdu: bytes) -> None: - self.connection.send_l2cap_pdu(att.ATT_CID, pdu) + if att.is_enhanced_bearer(self.bearer): + self.bearer.write(pdu) + else: + self.bearer.send_l2cap_pdu(att.ATT_CID, pdu) async def send_command(self, command: att.ATT_PDU) -> None: - logger.debug( - f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' - ) + logger.debug(f'GATT Command from client: {self._bearer_id} {command}') self.send_gatt_pdu(bytes(command)) async def send_request(self, request: att.ATT_PDU): - logger.debug( - f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' - ) + logger.debug(f'GATT Request from client: {self._bearer_id} {request}') # Wait until we can send (only one pending command at a time for the connection) response = None @@ -321,10 +380,7 @@ async def send_request(self, request: att.ATT_PDU): def send_confirmation( self, confirmation: att.ATT_Handle_Value_Confirmation ) -> None: - logger.debug( - f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' - f'{confirmation}' - ) + logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}') self.send_gatt_pdu(bytes(confirmation)) async def request_mtu(self, mtu: int) -> int: @@ -336,7 +392,7 @@ async def request_mtu(self, mtu: int) -> int: # We can only send one request per connection if self.mtu_exchange_done: - return self.connection.att_mtu + return self.mtu # Send the request self.mtu_exchange_done = True @@ -347,9 +403,9 @@ async def request_mtu(self, mtu: int) -> int: raise att.ATT_Error(error_code=response.error_code, message=response) # Compute the final MTU - self.connection.att_mtu = min(mtu, response.server_rx_mtu) + self.mtu = min(mtu, response.server_rx_mtu) - return self.connection.att_mtu + return self.mtu def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]: return [service for service in self.services if service.uuid == uuid] @@ -942,7 +998,7 @@ async def read_value( # If the value is the max size for the MTU, try to read more unless the caller # specifically asked not to do that attribute_value = response.attribute_value - if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1: + if not no_long_read and len(attribute_value) == self.mtu - 1: logger.debug('using READ BLOB to get the rest of the value') offset = len(attribute_value) while True: @@ -966,7 +1022,7 @@ async def read_value( part = response.part_attribute_value attribute_value += part - if len(part) < self.connection.att_mtu - 1: + if len(part) < self.mtu - 1: break offset += len(part) @@ -1062,14 +1118,13 @@ async def write_value( ) ) - def on_disconnection(self, _) -> None: + def on_disconnection(self, *args) -> None: + del args # unused. if self.pending_response and not self.pending_response.done(): self.pending_response.cancel() def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None: - logger.debug( - f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' - ) + logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}') if att_pdu.op_code in att.ATT_RESPONSES: if self.pending_request is None: # Not expected! @@ -1099,8 +1154,7 @@ def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None: else: logger.warning( color( - '--- Ignoring GATT Response from ' - f'[0x{self.connection.handle:04X}]: ', + '--- Ignoring GATT Response from ' f'{self._bearer_id}: ', 'red', ) + str(att_pdu) diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index bfa20990..3d29a76c 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -32,9 +32,8 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, TypeVar -from bumble import att, utils +from bumble import att, core, l2cap, utils from bumble.colors import color -from bumble.core import UUID from bumble.gatt import ( GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, @@ -44,14 +43,13 @@ GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic, CharacteristicDeclaration, - CharacteristicValue, Descriptor, IncludedServiceDeclaration, Service, ) if TYPE_CHECKING: - from bumble.device import Connection, Device + from bumble.device import Device # ----------------------------------------------------------------------------- # Logging @@ -65,6 +63,18 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517 +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _bearer_id(bearer: att.Bearer) -> str: + if att.is_enhanced_bearer(bearer): + return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]' + else: + return f'[0x{bearer.handle:04X}]' + + # ----------------------------------------------------------------------------- # GATT Server # ----------------------------------------------------------------------------- @@ -72,9 +82,9 @@ class Server(utils.EventEmitter): attributes: list[att.Attribute] services: list[Service] attributes_by_handle: dict[int, att.Attribute] - subscribers: dict[int, dict[int, bytes]] - indication_semaphores: defaultdict[int, asyncio.Semaphore] - pending_confirmations: defaultdict[int, asyncio.futures.Future | None] + subscribers: dict[att.Bearer, dict[int, bytes]] + indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore] + pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None] EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription" @@ -96,8 +106,29 @@ def __init__(self, device: Device) -> None: def __str__(self) -> str: return "\n".join(map(str, self.attributes)) - def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None: - self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu) + def register_eatt( + self, spec: l2cap.LeCreditBasedChannelSpec | None = None + ) -> l2cap.LeCreditBasedChannelServer: + def on_channel(channel: l2cap.LeCreditBasedChannel): + logger.debug( + "New EATT Bearer Conenction=0x%04X CID=0x%04X", + channel.connection.handle, + channel.source_cid, + ) + channel.att_mtu = att.ATT_DEFAULT_MTU + channel.sink = lambda pdu: self.on_gatt_pdu( + channel, att.ATT_PDU.from_bytes(pdu) + ) + + return self.device.create_l2cap_server( + spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel + ) + + def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None: + if att.is_enhanced_bearer(bearer): + bearer.write(pdu) + else: + self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu) def next_handle(self) -> int: return 1 + len(self.attributes) @@ -138,7 +169,7 @@ def get_attribute_group( None, ) - def get_service_attribute(self, service_uuid: UUID) -> Service | None: + def get_service_attribute(self, service_uuid: core.UUID) -> Service | None: return next( ( attribute @@ -151,7 +182,7 @@ def get_service_attribute(self, service_uuid: UUID) -> Service | None: ) def get_characteristic_attributes( - self, service_uuid: UUID, characteristic_uuid: UUID + self, service_uuid: core.UUID, characteristic_uuid: core.UUID ) -> tuple[CharacteristicDeclaration, Characteristic] | None: service_handle = self.get_service_attribute(service_uuid) if not service_handle: @@ -176,7 +207,10 @@ def get_characteristic_attributes( ) def get_descriptor_attribute( - self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID + self, + service_uuid: core.UUID, + characteristic_uuid: core.UUID, + descriptor_uuid: core.UUID, ) -> Descriptor | None: characteristics = self.get_characteristic_attributes( service_uuid, characteristic_uuid @@ -257,14 +291,7 @@ def add_service(self, service: Service) -> None: Descriptor( GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, att.Attribute.READABLE | att.Attribute.WRITEABLE, - CharacteristicValue( - read=lambda connection, characteristic=characteristic: self.read_cccd( - connection, characteristic - ), - write=lambda connection, value, characteristic=characteristic: self.write_cccd( - connection, characteristic, value - ), - ), + self.make_descriptor_value(characteristic), ) ) @@ -280,10 +307,21 @@ def add_services(self, services: Iterable[Service]) -> None: for service in services: self.add_service(service) - def read_cccd( - self, connection: Connection, characteristic: Characteristic - ) -> bytes: - subscribers = self.subscribers.get(connection.handle) + def make_descriptor_value( + self, characteristic: Characteristic + ) -> att.AttributeValueV2: + # It is necessary to use Attribute Value V2 here to identify the bearer of CCCD. + return att.AttributeValueV2( + lambda bearer, characteristic=characteristic: self.read_cccd( + bearer, characteristic + ), + write=lambda bearer, value, characteristic=characteristic: self.write_cccd( + bearer, characteristic, value + ), + ) + + def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes: + subscribers = self.subscribers.get(bearer) cccd = None if subscribers: cccd = subscribers.get(characteristic.handle) @@ -292,12 +330,12 @@ def read_cccd( def write_cccd( self, - connection: Connection, + bearer: att.Bearer, characteristic: Characteristic, value: bytes, ) -> None: logger.debug( - f'Subscription update for connection=0x{connection.handle:04X}, ' + f'Subscription update for connection={_bearer_id(bearer)}, ' f'handle=0x{characteristic.handle:04X}: {value.hex()}' ) @@ -306,41 +344,60 @@ def write_cccd( logger.warning('CCCD value not 2 bytes long') return - cccds = self.subscribers.setdefault(connection.handle, {}) + cccds = self.subscribers.setdefault(bearer, {}) cccds[characteristic.handle] = value logger.debug(f'CCCDs: {cccds}') notify_enabled = value[0] & 0x01 != 0 indicate_enabled = value[0] & 0x02 != 0 characteristic.emit( characteristic.EVENT_SUBSCRIPTION, - connection, + bearer, notify_enabled, indicate_enabled, ) self.emit( self.EVENT_CHARACTERISTIC_SUBSCRIPTION, - connection, + bearer, characteristic, notify_enabled, indicate_enabled, ) - def send_response(self, connection: Connection, response: att.ATT_PDU) -> None: - logger.debug( - f'GATT Response from server: [0x{connection.handle:04X}] {response}' - ) - self.send_gatt_pdu(connection.handle, bytes(response)) + def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None: + logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}') + self.send_gatt_pdu(bearer, bytes(response)) async def notify_subscriber( self, - connection: Connection, + bearer: att.Bearer, attribute: att.Attribute, value: bytes | None = None, force: bool = False, + ) -> None: + if att.is_enhanced_bearer(bearer) or force: + return await self._notify_single_subscriber(bearer, attribute, value, force) + else: + # If API is called to a Connection and not forced, try to notify all subscribed bearers on it. + bearers = [ + channel + for channel in self.device.l2cap_channel_manager.le_coc_channels.get( + bearer.handle, {} + ).values() + if channel.psm == att.EATT_PSM + ] + [bearer] + for bearer in bearers: + await self._notify_single_subscriber(bearer, attribute, value, force) + + async def _notify_single_subscriber( + self, + bearer: att.Bearer, + attribute: att.Attribute, + value: bytes | None, + force: bool, ) -> None: # Check if there's a subscriber if not force: - subscribers = self.subscribers.get(connection.handle) + subscribers = self.subscribers.get(bearer) if not subscribers: logger.debug('not notifying, no subscribers') return @@ -356,34 +413,53 @@ async def notify_subscriber( # Get or encode the value value = ( - await attribute.read_value(connection) + await attribute.read_value(bearer) if value is None else attribute.encode_value(value) ) # Truncate if needed - if len(value) > connection.att_mtu - 3: - value = value[: connection.att_mtu - 3] + if len(value) > bearer.att_mtu - 3: + value = value[: bearer.att_mtu - 3] # Notify notification = att.ATT_Handle_Value_Notification( attribute_handle=attribute.handle, attribute_value=value ) - logger.debug( - f'GATT Notify from server: [0x{connection.handle:04X}] {notification}' - ) - self.send_gatt_pdu(connection.handle, bytes(notification)) + logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}') + self.send_gatt_pdu(bearer, bytes(notification)) async def indicate_subscriber( self, - connection: Connection, + bearer: att.Bearer, attribute: att.Attribute, value: bytes | None = None, force: bool = False, + ) -> None: + if att.is_enhanced_bearer(bearer) or force: + return await self._notify_single_subscriber(bearer, attribute, value, force) + else: + # If API is called to a Connection and not forced, try to indicate all subscribed bearers on it. + bearers = [ + channel + for channel in self.device.l2cap_channel_manager.le_coc_channels.get( + bearer.handle, {} + ).values() + if channel.psm == att.EATT_PSM + ] + [bearer] + for bearer in bearers: + await self._indicate_single_bearer(bearer, attribute, value, force) + + async def _indicate_single_bearer( + self, + bearer: att.Bearer, + attribute: att.Attribute, + value: bytes | None, + force: bool, ) -> None: # Check if there's a subscriber if not force: - subscribers = self.subscribers.get(connection.handle) + subscribers = self.subscribers.get(bearer) if not subscribers: logger.debug('not indicating, no subscribers') return @@ -399,40 +475,38 @@ async def indicate_subscriber( # Get or encode the value value = ( - await attribute.read_value(connection) + await attribute.read_value(bearer) if value is None else attribute.encode_value(value) ) # Truncate if needed - if len(value) > connection.att_mtu - 3: - value = value[: connection.att_mtu - 3] + if len(value) > bearer.att_mtu - 3: + value = value[: bearer.att_mtu - 3] # Indicate indication = att.ATT_Handle_Value_Indication( attribute_handle=attribute.handle, attribute_value=value ) - logger.debug( - f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}' - ) + logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}') # Wait until we can send (only one pending indication at a time per connection) - async with self.indication_semaphores[connection.handle]: - assert self.pending_confirmations[connection.handle] is None + async with self.indication_semaphores[bearer]: + assert self.pending_confirmations[bearer] is None # Create a future value to hold the eventual response - pending_confirmation = self.pending_confirmations[connection.handle] = ( + pending_confirmation = self.pending_confirmations[bearer] = ( asyncio.get_running_loop().create_future() ) try: - self.send_gatt_pdu(connection.handle, bytes(indication)) + self.send_gatt_pdu(bearer, bytes(indication)) await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT) except asyncio.TimeoutError as error: logger.warning(color('!!! GATT Indicate timeout', 'red')) raise TimeoutError(f'GATT timeout for {indication.name}') from error finally: - self.pending_confirmations[connection.handle] = None + self.pending_confirmations[bearer] = None async def _notify_or_indicate_subscribers( self, @@ -441,24 +515,24 @@ async def _notify_or_indicate_subscribers( value: bytes | None = None, force: bool = False, ) -> None: - # Get all the connections for which there's at least one subscription - connections = [ - connection - for connection in [ - self.device.lookup_connection(connection_handle) - for (connection_handle, subscribers) in self.subscribers.items() - if force or subscribers.get(attribute.handle) - ] - if connection is not None + # Get all the bearers for which there's at least one subscription + bearers: list[att.Bearer] = [ + bearer + for bearer, subscribers in self.subscribers.items() + if force or subscribers.get(attribute.handle) ] # Indicate or notify for each connection - if connections: - coroutine = self.indicate_subscriber if indicate else self.notify_subscriber + if bearers: + coroutine = ( + self._indicate_single_bearer + if indicate + else self._notify_single_subscriber + ) await asyncio.wait( [ - asyncio.create_task(coroutine(connection, attribute, value, force)) - for connection in connections + asyncio.create_task(coroutine(bearer, attribute, value, force)) + for bearer in bearers ] ) @@ -480,21 +554,18 @@ async def indicate_subscribers( ): return await self._notify_or_indicate_subscribers(True, attribute, value, force) - def on_disconnection(self, connection: Connection) -> None: - if connection.handle in self.subscribers: - del self.subscribers[connection.handle] - if connection.handle in self.indication_semaphores: - del self.indication_semaphores[connection.handle] - if connection.handle in self.pending_confirmations: - del self.pending_confirmations[connection.handle] + def on_disconnection(self, bearer: att.Bearer) -> None: + self.subscribers.pop(bearer, None) + self.indication_semaphores.pop(bearer, None) + self.pending_confirmations.pop(bearer, None) - def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: - logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') + def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None: + logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}') handler_name = f'on_{att_pdu.name.lower()}' handler = getattr(self, handler_name, None) if handler is not None: try: - handler(connection, att_pdu) + handler(bearer, att_pdu) except att.ATT_Error as error: logger.debug(f'normal exception returned by handler: {error}') response = att.ATT_Error_Response( @@ -502,7 +573,7 @@ def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: attribute_handle_in_error=error.att_handle, error_code=error.error_code, ) - self.send_response(connection, response) + self.send_response(bearer, response) except Exception: logger.exception(color("!!! Exception in handler:", "red")) response = att.ATT_Error_Response( @@ -510,18 +581,18 @@ def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: attribute_handle_in_error=0x0000, error_code=att.ATT_UNLIKELY_ERROR_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) raise else: # No specific handler registered if att_pdu.op_code in att.ATT_REQUESTS: # Invoke the generic handler - self.on_att_request(connection, att_pdu) + self.on_att_request(bearer, att_pdu) else: # Just ignore logger.warning( color( - f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ', + f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ', 'red', ) + str(att_pdu) @@ -530,13 +601,14 @@ def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None: ####################################################### # ATT handlers ####################################################### - def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None: + def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None: ''' Handler for requests without a more specific handler ''' logger.warning( color( - f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red' + f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ', + 'red', ) + str(pdu) ) @@ -545,29 +617,28 @@ def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None: attribute_handle_in_error=0x0000, error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) def on_att_exchange_mtu_request( - self, connection: Connection, request: att.ATT_Exchange_MTU_Request + self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request ''' self.send_response( - connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) + bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu) ) # Compute the final MTU if request.client_rx_mtu >= att.ATT_DEFAULT_MTU: mtu = min(self.max_mtu, request.client_rx_mtu) - # Notify the device - self.device.on_connection_att_mtu_update(connection.handle, mtu) + bearer.on_att_mtu_update(mtu) else: logger.warning('invalid client_rx_mtu received, MTU not changed') def on_att_find_information_request( - self, connection: Connection, request: att.ATT_Find_Information_Request + self, bearer: att.Bearer, request: att.ATT_Find_Information_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request @@ -580,7 +651,7 @@ def on_att_find_information_request( or request.starting_handle > request.ending_handle ): self.send_response( - connection, + bearer, att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.starting_handle, @@ -590,7 +661,7 @@ def on_att_find_information_request( return # Build list of returned attributes - pdu_space_available = connection.att_mtu - 2 + pdu_space_available = bearer.att_mtu - 2 attributes: list[att.Attribute] = [] uuid_size = 0 for attribute in ( @@ -632,18 +703,18 @@ def on_att_find_information_request( error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_find_by_type_value_request( - self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request + self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request ''' # Build list of returned attributes - pdu_space_available = connection.att_mtu - 2 + pdu_space_available = bearer.att_mtu - 2 attributes = [] response: att.ATT_PDU async for attribute in ( @@ -652,7 +723,7 @@ async def on_att_find_by_type_value_request( if attribute.handle >= request.starting_handle and attribute.handle <= request.ending_handle and attribute.type == request.attribute_type - and (await attribute.read_value(connection)) == request.attribute_value + and (await attribute.read_value(bearer)) == request.attribute_value and pdu_space_available >= 4 ): # TODO: check permissions @@ -688,17 +759,17 @@ async def on_att_find_by_type_value_request( error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_read_by_type_request( - self, connection: Connection, request: att.ATT_Read_By_Type_Request + self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request ''' - pdu_space_available = connection.att_mtu - 2 + pdu_space_available = bearer.att_mtu - 2 response: att.ATT_PDU = att.ATT_Error_Response( request_opcode_in_error=request.op_code, @@ -716,7 +787,7 @@ async def on_att_read_by_type_request( and pdu_space_available ): try: - attribute_value = await attribute.read_value(connection) + attribute_value = await attribute.read_value(bearer) except att.ATT_Error as error: # If the first attribute is unreadable, return an error # Otherwise return attributes up to this point @@ -729,7 +800,7 @@ async def on_att_read_by_type_request( break # Check the attribute value size - max_attribute_size = min(connection.att_mtu - 4, 253) + max_attribute_size = min(bearer.att_mtu - 4, 253) if len(attribute_value) > max_attribute_size: # We need to truncate attribute_value = attribute_value[:max_attribute_size] @@ -756,11 +827,11 @@ async def on_att_read_by_type_request( else: logging.debug(f"not found {request}") - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_read_request( - self, connection: Connection, request: att.ATT_Read_Request + self, bearer: att.Bearer, request: att.ATT_Read_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request @@ -769,7 +840,7 @@ async def on_att_read_request( response: att.ATT_PDU if attribute := self.get_attribute(request.attribute_handle): try: - value = await attribute.read_value(connection) + value = await attribute.read_value(bearer) except att.ATT_Error as error: response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, @@ -777,7 +848,7 @@ async def on_att_read_request( error_code=error.error_code, ) else: - value_size = min(connection.att_mtu - 1, len(value)) + value_size = min(bearer.att_mtu - 1, len(value)) response = att.ATT_Read_Response(attribute_value=value[:value_size]) else: response = att.ATT_Error_Response( @@ -785,11 +856,11 @@ async def on_att_read_request( attribute_handle_in_error=request.attribute_handle, error_code=att.ATT_INVALID_HANDLE_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_read_blob_request( - self, connection: Connection, request: att.ATT_Read_Blob_Request + self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request @@ -798,7 +869,7 @@ async def on_att_read_blob_request( response: att.ATT_PDU if attribute := self.get_attribute(request.attribute_handle): try: - value = await attribute.read_value(connection) + value = await attribute.read_value(bearer) except att.ATT_Error as error: response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, @@ -812,7 +883,7 @@ async def on_att_read_blob_request( attribute_handle_in_error=request.attribute_handle, error_code=att.ATT_INVALID_OFFSET_ERROR, ) - elif len(value) <= connection.att_mtu - 1: + elif len(value) <= bearer.att_mtu - 1: response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, @@ -820,7 +891,7 @@ async def on_att_read_blob_request( ) else: part_size = min( - connection.att_mtu - 1, len(value) - request.value_offset + bearer.att_mtu - 1, len(value) - request.value_offset ) response = att.ATT_Read_Blob_Response( part_attribute_value=value[ @@ -833,11 +904,11 @@ async def on_att_read_blob_request( attribute_handle_in_error=request.attribute_handle, error_code=att.ATT_INVALID_HANDLE_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_read_by_group_type_request( - self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request + self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request @@ -852,10 +923,10 @@ async def on_att_read_by_group_type_request( attribute_handle_in_error=request.starting_handle, error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) return - pdu_space_available = connection.att_mtu - 2 + pdu_space_available = bearer.att_mtu - 2 attributes: list[tuple[int, int, bytes]] = [] for attribute in ( attribute @@ -867,9 +938,9 @@ async def on_att_read_by_group_type_request( ): # No need to catch permission errors here, since these attributes # must all be world-readable - attribute_value = await attribute.read_value(connection) + attribute_value = await attribute.read_value(bearer) # Check the attribute value size - max_attribute_size = min(connection.att_mtu - 6, 251) + max_attribute_size = min(bearer.att_mtu - 6, 251) if len(attribute_value) > max_attribute_size: # We need to truncate attribute_value = attribute_value[:max_attribute_size] @@ -904,11 +975,11 @@ async def on_att_read_by_group_type_request( error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR, ) - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_write_request( - self, connection: Connection, request: att.ATT_Write_Request + self, bearer: att.Bearer, request: att.ATT_Write_Request ): ''' See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request @@ -918,7 +989,7 @@ async def on_att_write_request( attribute = self.get_attribute(request.attribute_handle) if attribute is None: self.send_response( - connection, + bearer, att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, @@ -932,7 +1003,7 @@ async def on_att_write_request( # Check the request parameters if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE: self.send_response( - connection, + bearer, att.ATT_Error_Response( request_opcode_in_error=request.op_code, attribute_handle_in_error=request.attribute_handle, @@ -944,7 +1015,7 @@ async def on_att_write_request( response: att.ATT_PDU try: # Accept the value - await attribute.write_value(connection, request.attribute_value) + await attribute.write_value(bearer, request.attribute_value) except att.ATT_Error as error: response = att.ATT_Error_Response( request_opcode_in_error=request.op_code, @@ -954,11 +1025,11 @@ async def on_att_write_request( else: # Done response = att.ATT_Write_Response() - self.send_response(connection, response) + self.send_response(bearer, response) @utils.AsyncRunner.run_in_task() async def on_att_write_command( - self, connection: Connection, request: att.ATT_Write_Command + self, bearer: att.Bearer, request: att.ATT_Write_Command ): ''' See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command @@ -977,22 +1048,20 @@ async def on_att_write_command( # Accept the value try: - await attribute.write_value(connection, request.attribute_value) + await attribute.write_value(bearer, request.attribute_value) except Exception: logger.exception('!!! ignoring exception') def on_att_handle_value_confirmation( self, - connection: Connection, + bearer: att.Bearer, confirmation: att.ATT_Handle_Value_Confirmation, ): ''' See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation ''' del confirmation # Unused. - if ( - pending_confirmation := self.pending_confirmations[connection.handle] - ) is None: + if (pending_confirmation := self.pending_confirmations[bearer]) is None: # Not expected! logger.warning( '!!! unexpected confirmation, there is no pending indication' diff --git a/bumble/l2cap.py b/bumble/l2cap.py index e79c3204..b595ec60 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -1552,6 +1552,7 @@ class State(enum.IntEnum): EVENT_OPEN = "open" EVENT_CLOSE = "close" + EVENT_ATT_MTU_UPDATE = "att_mtu_update" def __init__( self, @@ -1591,6 +1592,7 @@ def __init__( self.connection_result = None self.disconnection_result = None self.drained = asyncio.Event() + self.att_mtu = 0 # Filled by GATT client or server later. self.drained.set() @@ -1821,6 +1823,10 @@ def on_disconnection_response(self, response: L2CAP_Disconnection_Response) -> N self.disconnection_result.set_result(None) self.disconnection_result = None + def on_att_mtu_update(self, mtu: int) -> None: + self.att_mtu = mtu + self.emit(self.EVENT_ATT_MTU_UPDATE, mtu) + def flush_output(self) -> None: self.out_queue.clear() self.out_sdu = None diff --git a/examples/run_gatt_client.py b/examples/run_gatt_client.py index c0af85b0..4786a02e 100644 --- a/examples/run_gatt_client.py +++ b/examples/run_gatt_client.py @@ -19,10 +19,10 @@ import sys import bumble.logging +from bumble import gatt_client from bumble.colors import color from bumble.core import ProtocolError -from bumble.device import Device, Peer -from bumble.gatt import show_services +from bumble.device import Connection, Device from bumble.transport import open_transport from bumble.utils import AsyncRunner @@ -34,24 +34,27 @@ def __init__(self, device): @AsyncRunner.run_in_task() # pylint: disable=invalid-overridden-method - async def on_connection(self, connection): + async def on_connection(self, connection: Connection): print(f'=== Connected to {connection}') # Discover all services print('=== Discovering services') - peer = Peer(connection) - await peer.discover_services() - for service in peer.services: + if connection.device.config.eatt_enabled: + client = await gatt_client.Client.connect_eatt(connection) + else: + client = connection.gatt_client + await client.discover_services() + for service in client.services: await service.discover_characteristics() for characteristic in service.characteristics: await characteristic.discover_descriptors() print('=== Services discovered') - show_services(peer.services) + gatt_client.show_services(client.services) # Discover all attributes print('=== Discovering attributes') - attributes = await peer.discover_attributes() + attributes = await client.discover_attributes() for attribute in attributes: print(attribute) print('=== Attributes discovered') @@ -59,7 +62,7 @@ async def on_connection(self, connection): # Read all attributes for attribute in attributes: try: - value = await peer.read_value(attribute) + value = await client.read_value(attribute) print(color(f'0x{attribute.handle:04X} = {value.hex()}', 'green')) except ProtocolError as error: print(color(f'cannot read {attribute.handle:04X}:', 'red'), error) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index ecb489db..f2f3204f 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -28,6 +28,7 @@ import pytest from typing_extensions import Self +from bumble import gatt_client, l2cap from bumble.att import ( ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_PDU, @@ -63,7 +64,6 @@ UTF8CharacteristicAdapter, UTF8CharacteristicProxyAdapter, ) -from bumble.gatt_client import CharacteristicProxy from .test_utils import Devices, TwoDevices, async_barrier @@ -140,7 +140,7 @@ def decode_value(self, value_bytes): await c.write_value(Mock(), bytes([122])) assert c.value == 122 - class FooProxy(CharacteristicProxy): + class FooProxy(gatt_client.CharacteristicProxy): def __init__(self, characteristic): super().__init__( characteristic.client, @@ -456,7 +456,7 @@ async def read_value(self, handle, no_long_read=False) -> bytes: async def write_value(self, handle, value, with_response=False): self.value = value - class TestAttributeProxy(CharacteristicProxy): + class TestAttributeProxy(gatt_client.CharacteristicProxy): def __init__(self, value) -> None: super().__init__(Client(value), 0, 0, None, 0) # type: ignore @@ -1425,10 +1425,10 @@ async def test_get_characteristics_by_uuid(): await peer.discover_characteristics() c = peer.get_characteristics_by_uuid(uuid=UUID('1234')) assert len(c) == 2 - assert isinstance(c[0], CharacteristicProxy) + assert isinstance(c[0], gatt_client.CharacteristicProxy) c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('ABCD')) assert len(c) == 1 - assert isinstance(c[0], CharacteristicProxy) + assert isinstance(c[0], gatt_client.CharacteristicProxy) c = peer.get_characteristics_by_uuid(uuid=UUID('1234'), service=UUID('AAAA')) assert len(c) == 0 @@ -1463,6 +1463,181 @@ async def test_write_return_error(): assert e.value.error_code == ErrorCode.VALUE_NOT_ALLOWED +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_eatt_read(): + devices = await TwoDevices.create_with_connection() + devices[1].gatt_server.register_eatt() + + characteristic = Characteristic( + '1234', + Characteristic.Properties.READ, + Characteristic.Permissions.READABLE, + b'9999', + ) + service = Service('ABCD', [characteristic]) + devices[1].add_service(service) + + client = await gatt_client.Client.connect_eatt(devices.connections[0]) + await client.discover_services() + service_proxy = client.get_services_by_uuid(service.uuid)[0] + await service_proxy.discover_characteristics() + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + assert await characteristic_proxy.read_value() == b'9999' + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_eatt_write(): + devices = await TwoDevices.create_with_connection() + devices[1].gatt_server.register_eatt() + + write_queue = asyncio.Queue() + characteristic = Characteristic( + '1234', + Characteristic.Properties.WRITE, + Characteristic.Permissions.WRITEABLE, + CharacteristicValue(write=lambda *args: write_queue.put_nowait(args)), + ) + service = Service('ABCD', [characteristic]) + devices[1].add_service(service) + + client = await gatt_client.Client.connect_eatt(devices.connections[0]) + await client.discover_services() + service_proxy = client.get_services_by_uuid(service.uuid)[0] + await service_proxy.discover_characteristics() + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + await characteristic_proxy.write_value(b'9999') + assert await write_queue.get() == (devices.connections[1], b'9999') + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_eatt_notify(): + devices = await TwoDevices.create_with_connection() + devices[1].gatt_server.register_eatt() + + characteristic = Characteristic( + '1234', + Characteristic.Properties.NOTIFY, + Characteristic.Permissions.WRITEABLE, + ) + service = Service('ABCD', [characteristic]) + devices[1].add_service(service) + + clients = [ + ( + devices.connections[0].gatt_client, + asyncio.Queue[bytes](), + ), + ( + await gatt_client.Client.connect_eatt(devices.connections[0]), + asyncio.Queue[bytes](), + ), + ( + await gatt_client.Client.connect_eatt(devices.connections[0]), + asyncio.Queue[bytes](), + ), + ] + for client, queue in clients: + await client.discover_services() + service_proxy = client.get_services_by_uuid(service.uuid)[0] + await service_proxy.discover_characteristics() + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + + for client, queue in clients[:2]: + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=True) + + await devices[1].gatt_server.notify_subscribers(characteristic, b'1234') + for _, queue in clients[:2]: + assert await queue.get() == b'1234' + assert queue.empty() + assert clients[2][1].empty() + + await devices[1].gatt_server.notify_subscriber( + devices.connections[1], characteristic, b'5678' + ) + for _, queue in clients[:2]: + assert await queue.get() == b'5678' + assert queue.empty() + assert clients[2][1].empty() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_eatt_indicate(): + devices = await TwoDevices.create_with_connection() + devices[1].gatt_server.register_eatt() + + characteristic = Characteristic( + '1234', + Characteristic.Properties.INDICATE, + Characteristic.Permissions.WRITEABLE, + ) + service = Service('ABCD', [characteristic]) + devices[1].add_service(service) + + clients = [ + ( + devices.connections[0].gatt_client, + asyncio.Queue[bytes](), + ), + ( + await gatt_client.Client.connect_eatt(devices.connections[0]), + asyncio.Queue[bytes](), + ), + ( + await gatt_client.Client.connect_eatt(devices.connections[0]), + asyncio.Queue[bytes](), + ), + ] + for client, queue in clients: + await client.discover_services() + service_proxy = client.get_services_by_uuid(service.uuid)[0] + await service_proxy.discover_characteristics() + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + + for client, queue in clients[:2]: + characteristic_proxy = service_proxy.get_characteristics_by_uuid( + characteristic.uuid + )[0] + await characteristic_proxy.subscribe(queue.put_nowait, prefer_notify=False) + + await devices[1].gatt_server.indicate_subscribers(characteristic, b'1234') + for _, queue in clients[:2]: + assert await queue.get() == b'1234' + assert queue.empty() + assert clients[2][1].empty() + + await devices[1].gatt_server.indicate_subscriber( + devices.connections[1], characteristic, b'5678' + ) + for _, queue in clients[:2]: + assert await queue.get() == b'5678' + assert queue.empty() + assert clients[2][1].empty() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_eatt_connection_failure(): + devices = await TwoDevices.create_with_connection() + + with pytest.raises(l2cap.L2capError): + await gatt_client.Client.connect_eatt(devices.connections[0]) + + # ----------------------------------------------------------------------------- if __name__ == '__main__': logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())