Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/lea_unicast/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ async def run(self) -> None:
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
)

device_config.le_enabled = True
Expand Down
66 changes: 63 additions & 3 deletions bumble/att.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@
TYPE_CHECKING,
ClassVar,
Generic,
TypeAlias,
TypeVar,
)

from bumble import hci, utils
from typing_extensions import TypeIs

from bumble import hci, l2cap, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
Expand All @@ -50,6 +53,14 @@

_T = TypeVar('_T')

Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel


def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
Expand All @@ -58,6 +69,7 @@

ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027

class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
Expand Down Expand Up @@ -780,6 +792,32 @@ def write(self, connection: Connection, value: _T) -> Awaitable[None] | None:
return self._write(connection, value)


# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a docstring/comment here to explain why/when this alternate class is needed (i.e only needed when the full bearer info is needed for the read and write callbacks), but otherwise works the same as AttributeValue).

def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write

def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)

def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)


# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
Expand Down Expand Up @@ -855,7 +893,8 @@ def encode_value(self, value: _T) -> bytes:
def decode_value(self, value: bytes) -> _T:
return value # type: ignore

async def read_value(self, connection: Connection) -> bytes:
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
Expand Down Expand Up @@ -890,14 +929,26 @@ async def read_value(self, connection: Connection) -> bytes:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value

self.emit(self.EVENT_READ, connection, b'' if value is None else value)

return b'' if value is None else self.encode_value(value)

async def write_value(self, connection: Connection, value: bytes) -> None:
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
Expand Down Expand Up @@ -931,6 +982,15 @@ async def write_value(self, connection: Connection, value: bytes) -> None:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value

Expand Down
39 changes: 22 additions & 17 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from typing_extensions import Self

from bumble import (
att,
core,
data_types,
gatt,
Expand All @@ -53,7 +54,6 @@
smp,
utils,
)
from bumble.att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
from bumble.colors import color
from bumble.core import (
AdvertisingData,
Expand Down Expand Up @@ -1743,7 +1743,6 @@ class Connection(utils.CompositeEventEmitter):
EVENT_CONNECTION_PARAMETERS_UPDATE_FAILURE = "connection_parameters_update_failure"
EVENT_CONNECTION_PHY_UPDATE = "connection_phy_update"
EVENT_CONNECTION_PHY_UPDATE_FAILURE = "connection_phy_update_failure"
EVENT_CONNECTION_ATT_MTU_UPDATE = "connection_att_mtu_update"
EVENT_CONNECTION_DATA_LENGTH_CHANGE = "connection_data_length_change"
EVENT_CHANNEL_SOUNDING_CAPABILITIES_FAILURE = (
"channel_sounding_capabilities_failure"
Expand Down Expand Up @@ -1846,7 +1845,7 @@ def __init__(
self.encryption_key_size = 0
self.authenticated = False
self.sc = False
self.att_mtu = ATT_DEFAULT_MTU
self.att_mtu = att.ATT_DEFAULT_MTU
self.data_length = DEVICE_DEFAULT_DATA_LENGTH
self.gatt_client = gatt_client.Client(self) # Per-connection client
self.gatt_server = (
Expand Down Expand Up @@ -1996,6 +1995,15 @@ async def get_remote_le_features(self) -> hci.LeFeatureMask:
self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features

def on_att_mtu_update(self, mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{self.handle:04X}] '
f'{self.peer_address} as {self.role_name}, '
f'{mtu}'
)
self.att_mtu = mtu
self.emit(self.EVENT_CONNECTION_ATT_MTU_UPDATE)

@property
def data_packet_queue(self) -> DataPacketQueue | None:
return self.device.host.get_data_packet_queue(self.handle)
Expand Down Expand Up @@ -2079,6 +2087,7 @@ class DeviceConfiguration:
l2cap.L2CAP_Information_Request.ExtendedFeatures.FCS_OPTION,
l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE,
)
eatt_enabled: bool = False

def __post_init__(self) -> None:
self.gatt_services: list[dict[str, Any]] = []
Expand Down Expand Up @@ -2497,7 +2506,10 @@ def __init__(
add_gap_service=config.gap_service_enabled,
add_gatt_service=config.gatt_service_enabled,
)
self.l2cap_channel_manager.register_fixed_channel(ATT_CID, self.on_gatt_pdu)
self.l2cap_channel_manager.register_fixed_channel(att.ATT_CID, self.on_gatt_pdu)

if self.config.eatt_enabled:
self.gatt_server.register_eatt()

# Forward some events
utils.setup_event_forwarding(
Expand Down Expand Up @@ -5140,7 +5152,11 @@ def add_default_services(
if add_gap_service:
self.gatt_server.add_service(GenericAccessService(self.name))
if add_gatt_service:
self.gatt_service = gatt_service.GenericAttributeProfileService()
self.gatt_service = gatt_service.GenericAttributeProfileService(
gatt.ServerSupportedFeatures.EATT_SUPPORTED
if self.config.eatt_enabled
else None
)
self.gatt_server.add_service(self.gatt_service)

async def notify_subscriber(
Expand Down Expand Up @@ -6240,17 +6256,6 @@ def on_le_subrate_change(
)
connection.emit(connection.EVENT_LE_SUBRATE_CHANGE)

@host_event_handler
@with_connection_from_handle
def on_connection_att_mtu_update(self, connection: Connection, att_mtu: int):
logger.debug(
f'*** Connection ATT MTU Update: [0x{connection.handle:04X}] '
f'{connection.peer_address} as {connection.role_name}, '
f'{att_mtu}'
)
connection.att_mtu = att_mtu
connection.emit(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)

@host_event_handler
@with_connection_from_handle
def on_connection_data_length_change(
Expand Down Expand Up @@ -6437,7 +6442,7 @@ def on_pairing_failure(self, connection: Connection, reason: int) -> None:
@with_connection_from_handle
def on_gatt_pdu(self, connection: Connection, pdu: bytes):
# Parse the L2CAP payload into an ATT PDU object
att_pdu = ATT_PDU.from_bytes(pdu)
att_pdu = att.ATT_PDU.from_bytes(pdu)

# Conveniently, even-numbered op codes are client->server and
# odd-numbered ones are server->client
Expand Down
4 changes: 2 additions & 2 deletions bumble/gatt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from collections.abc import Iterable, Sequence
from typing import TypeVar

from bumble.att import Attribute, AttributeValue
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color
from bumble.core import UUID, BaseBumbleError

Expand Down Expand Up @@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
value_str = '<dynamic>'
else:
value_str = '<...>'
Expand Down
Loading
Loading