diff --git a/bumble/hid.py b/bumble/hid.py index 252c54f5..8f70866b 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -17,17 +17,16 @@ # ----------------------------------------------------------------------------- from __future__ import annotations -import enum +import asyncio import logging import struct from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional, TypeVar from typing_extensions import override -from bumble import device, l2cap, utils -from bumble.core import InvalidStateError, ProtocolError +from bumble import core, device, l2cap, utils from bumble.hci import Address # ----------------------------------------------------------------------------- @@ -44,28 +43,34 @@ HID_INTERRUPT_PSM = 0x0013 -class Message: - message_type: MessageType +class HidProtocolError(core.ProtocolError): + result_code: HandshakeMessage.ResultCode - # Report types - class ReportType(enum.IntEnum): - OTHER_REPORT = 0x00 - INPUT_REPORT = 0x01 - OUTPUT_REPORT = 0x02 - FEATURE_REPORT = 0x03 + def __init__(self, result_code: HandshakeMessage.ResultCode): + self.result_code = result_code + super().__init__( + result_code.value, error_namespace='HID', error_name=result_code.name + ) - # Handshake parameters - class Handshake(enum.IntEnum): - SUCCESSFUL = 0x00 - NOT_READY = 0x01 - ERR_INVALID_REPORT_ID = 0x02 - ERR_UNSUPPORTED_REQUEST = 0x03 - ERR_INVALID_PARAMETER = 0x04 - ERR_UNKNOWN = 0x0E - ERR_FATAL = 0x0F - # Message Type - class MessageType(enum.IntEnum): +# Report types +class ReportType(utils.OpenIntEnum): + OTHER_REPORT = 0x00 + INPUT_REPORT = 0x01 + OUTPUT_REPORT = 0x02 + FEATURE_REPORT = 0x03 + + +# Protocol modes +class ProtocolMode(utils.OpenIntEnum): + BOOT_PROTOCOL = 0x00 + REPORT_PROTOCOL = 0x01 + + +# Messages +class Message: + + class Type(utils.OpenIntEnum): HANDSHAKE = 0x00 CONTROL = 0x01 GET_REPORT = 0x04 @@ -74,497 +79,500 @@ class MessageType(enum.IntEnum): SET_PROTOCOL = 0x07 DATA = 0x0A - # Protocol modes - class ProtocolMode(enum.IntEnum): - BOOT_PROTOCOL = 0x00 - REPORT_PROTOCOL = 0x01 + message_type: Type - # Control Operations - class ControlCommand(enum.IntEnum): - SUSPEND = 0x03 - EXIT_SUSPEND = 0x04 - VIRTUAL_CABLE_UNPLUG = 0x05 + subclasses: ClassVar[dict[Type, type[Message]]] = {} + + _Message = TypeVar('_Message', bound='Message') + + @classmethod + def message(cls, subclass: type[_Message]) -> type[_Message]: + cls.subclasses[subclass.message_type] = subclass + return subclass # Class Method to derive header @classmethod def header(cls, lower_bits: int = 0x00) -> bytes: return bytes([(cls.message_type << 4) | lower_bits]) - -# HIDP messages -@dataclass -class GetReportMessage(Message): - report_type: int - report_id: int - buffer_size: int - message_type = Message.MessageType.GET_REPORT + @classmethod + def from_bytes(cls, data: bytes) -> Message: + message_type = Message.Type(data[0] >> 4) + if subclass := cls.subclasses.get(message_type): + return subclass.from_bytes(data) + else: + raise core.InvalidPacketError(f"Unknown message type {message_type.name}") def __bytes__(self) -> bytes: - packet_bytes = bytearray() - packet_bytes.append(self.report_id) - if self.buffer_size == 0: - return self.header(self.report_type) + packet_bytes - else: - return ( - self.header(0x08 | self.report_type) - + packet_bytes - + struct.pack(" bytes: - return self.header(self.report_type) + self.data +class HandshakeMessage(Message): + message_type = Message.Type.HANDSHAKE + class ResultCode(utils.OpenIntEnum): + SUCCESSFUL = 0x00 + NOT_READY = 0x01 + ERR_INVALID_REPORT_ID = 0x02 + ERR_UNSUPPORTED_REQUEST = 0x03 + ERR_INVALID_PARAMETER = 0x04 + ERR_UNKNOWN = 0x0E + ERR_FATAL = 0x0F -@dataclass -class SendControlData(Message): - report_type: int - data: bytes - message_type = Message.MessageType.DATA + result_code: ResultCode def __bytes__(self) -> bytes: - return self.header(self.report_type) + self.data + return self.header(self.result_code) + + @classmethod + def from_bytes(cls, data: bytes) -> HandshakeMessage: + return cls(result_code=cls.ResultCode(data[0] & 0xFF)) +@Message.message @dataclass -class GetProtocolMessage(Message): - message_type = Message.MessageType.GET_PROTOCOL +class ControlMessage(Message): + message_type = Message.Type.CONTROL + + class Command(utils.OpenIntEnum): + SUSPEND = 0x03 + EXIT_SUSPEND = 0x04 + VIRTUAL_CABLE_UNPLUG = 0x05 + + command: Command def __bytes__(self) -> bytes: - return self.header() + return self.header(self.command) + + @classmethod + def from_bytes(cls, data: bytes) -> ControlMessage: + return cls(command=ControlMessage.Command(data[0] & 0x0F)) +@Message.message @dataclass -class SetProtocolMessage(Message): - protocol_mode: int - message_type = Message.MessageType.SET_PROTOCOL +class GetReportMessage(Message): + message_type = Message.Type.GET_REPORT + FLAG_HAS_SIZE = 0x08 + + report_type: ReportType + report_id: Optional[int] = None + buffer_size: Optional[int] = None def __bytes__(self) -> bytes: - return self.header(self.protocol_mode) + data = self.header( + self.report_type + | (self.FLAG_HAS_SIZE if self.buffer_size is not None else 0) + ) + if self.report_id is not None: + data += bytes([self.report_id]) + if self.buffer_size is not None: + data += struct.pack(" GetReportMessage: + report_type = ReportType(data[0] & 0x03) + if len(data) == 1: + return cls(report_type=report_type) + report_id = data[1] + if data[0] & cls.FLAG_HAS_SIZE: + return cls( + report_type=report_type, + report_id=report_id, + buffer_size=struct.unpack(" bytes: - return self.header(Message.ControlCommand.SUSPEND) + return self.header(self.report_type) + self.data + + @classmethod + def from_bytes(cls, data: bytes) -> SetReportMessage: + return cls(report_type=ReportType(data[0] & 0x03), data=data[1:]) +@Message.message @dataclass -class ExitSuspend(Message): - message_type = Message.MessageType.CONTROL +class GetProtocolMessage(Message): + message_type = Message.Type.GET_PROTOCOL def __bytes__(self) -> bytes: - return self.header(Message.ControlCommand.EXIT_SUSPEND) + return self.header() + + @classmethod + def from_bytes(cls, data: bytes) -> GetProtocolMessage: + del data # unused. + return cls() +@Message.message @dataclass -class VirtualCableUnplug(Message): - message_type = Message.MessageType.CONTROL +class SetProtocolMessage(Message): + message_type = Message.Type.SET_PROTOCOL + + protocol_mode: ProtocolMode def __bytes__(self) -> bytes: - return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG) + return self.header(self.protocol_mode) + + @classmethod + def from_bytes(cls, data: bytes) -> SetProtocolMessage: + return cls(protocol_mode=ProtocolMode(data[0] & 0x01)) # Device sends input report, host sends output report. +@Message.message @dataclass -class SendData(Message): +class DataMessage(Message): + message_type = Message.Type.DATA + data: bytes - report_type: int - message_type = Message.MessageType.DATA + report_type: ReportType def __bytes__(self) -> bytes: return self.header(self.report_type) + self.data - -@dataclass -class SendHandshakeMessage(Message): - result_code: int - message_type = Message.MessageType.HANDSHAKE - - def __bytes__(self) -> bytes: - return self.header(self.result_code) + @classmethod + def from_bytes(cls, data: bytes) -> DataMessage: + return cls(data=data[1:], report_type=ReportType(data[0] & 0x03)) # ----------------------------------------------------------------------------- class HID(ABC, utils.EventEmitter): - l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None - l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None - connection: Optional[device.Connection] = None + control_channel: Optional[l2cap.ClassicChannel] = None + interrupt_channel: Optional[l2cap.ClassicChannel] = None EVENT_INTERRUPT_DATA = "interrupt_data" EVENT_CONTROL_DATA = "control_data" EVENT_SUSPEND = "suspend" EVENT_EXIT_SUSPEND = "exit_suspend" EVENT_VIRTUAL_CABLE_UNPLUG = "virtual_cable_unplug" - EVENT_HANDSHAKE = "handshake" + EVENT_CONNECTION = "connection" + EVENT_DISCONNECTION = "disconnection" - class Role(enum.IntEnum): + class Role(utils.OpenIntEnum): HOST = 0x00 DEVICE = 0x01 - def __init__(self, device: device.Device, role: Role) -> None: + role: ClassVar[Role] + + def __init__(self, device: device.Device) -> None: super().__init__() self.remote_device_bd_address: Optional[Address] = None self.device = device - self.role = role # Register ourselves with the L2CAP channel manager device.create_l2cap_server( - l2cap.ClassicChannelSpec(HID_CONTROL_PSM), self.on_l2cap_connection + l2cap.ClassicChannelSpec(HID_CONTROL_PSM), self._on_l2cap_connection ) device.create_l2cap_server( - l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM), self.on_l2cap_connection + l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM), self._on_l2cap_connection ) - device.on(device.EVENT_CONNECTION, self.on_device_connection) + async def connect(self, connection: device.Connection) -> None: + self.control_channel = await connection.create_l2cap_channel( + l2cap.ClassicChannelSpec(HID_CONTROL_PSM) + ) + self.control_channel.sink = self._on_control_pdu + self.interrupt_channel = await connection.create_l2cap_channel( + l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM) + ) + self.interrupt_channel.sink = self._on_interrupt_pdu - async def connect_control_channel(self) -> None: - if not self.connection: - raise InvalidStateError("Connection is not established!") - # Create a new L2CAP connection - control channel - try: - channel = await self.connection.create_l2cap_channel( - l2cap.ClassicChannelSpec(HID_CONTROL_PSM) - ) - channel.sink = self.on_ctrl_pdu - self.l2cap_ctrl_channel = channel - except ProtocolError: - logging.exception('L2CAP connection failed.') - raise - - async def connect_interrupt_channel(self) -> None: - if not self.connection: - raise InvalidStateError("Connection is not established!") - # Create a new L2CAP connection - interrupt channel - try: - channel = await self.connection.create_l2cap_channel( - l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM) - ) - channel.sink = self.on_intr_pdu - self.l2cap_intr_channel = channel - except ProtocolError: - logging.exception('L2CAP connection failed.') - raise - - async def disconnect_interrupt_channel(self) -> None: - if self.l2cap_intr_channel is None: - raise InvalidStateError('invalid state') - channel = self.l2cap_intr_channel - self.l2cap_intr_channel = None - await channel.disconnect() - - async def disconnect_control_channel(self) -> None: - if self.l2cap_ctrl_channel is None: - raise InvalidStateError('invalid state') - channel = self.l2cap_ctrl_channel - self.l2cap_ctrl_channel = None - await channel.disconnect() - - def on_device_connection(self, connection: device.Connection) -> None: - self.connection = connection - self.remote_device_bd_address = connection.peer_address - connection.on(connection.EVENT_DISCONNECTION, self.on_device_disconnection) - - def on_device_disconnection(self, reason: int) -> None: - self.connection = None - - def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: + async def disconnect(self) -> None: + if self.interrupt_channel: + await self.interrupt_channel.disconnect() + self.interrupt_channel = None + if self.control_channel: + await self.control_channel.disconnect() + self.control_channel = None + + def _on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') l2cap_channel.on( - l2cap_channel.EVENT_OPEN, lambda: self.on_l2cap_channel_open(l2cap_channel) + l2cap_channel.EVENT_OPEN, lambda: self._on_l2cap_channel_open(l2cap_channel) ) l2cap_channel.on( l2cap_channel.EVENT_CLOSE, - lambda: self.on_l2cap_channel_close(l2cap_channel), + lambda: self._on_l2cap_channel_close(l2cap_channel), ) - def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: + def _on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: if l2cap_channel.psm == HID_CONTROL_PSM: - self.l2cap_ctrl_channel = l2cap_channel - self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu + self.control_channel = l2cap_channel + self.control_channel.sink = self._on_control_pdu else: - self.l2cap_intr_channel = l2cap_channel - self.l2cap_intr_channel.sink = self.on_intr_pdu + self.interrupt_channel = l2cap_channel + self.interrupt_channel.sink = self._on_interrupt_pdu + if not self.control_channel: + logger.warning("Interrupt channel established before control channel!") logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') - def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None: + if self.control_channel and self.interrupt_channel: + self.emit(self.EVENT_CONNECTION) + + def _on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None: if l2cap_channel.psm == HID_CONTROL_PSM: - self.l2cap_ctrl_channel = None + self.control_channel = None else: - self.l2cap_intr_channel = None + self.interrupt_channel = None logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}') + if not self.control_channel and not self.interrupt_channel: + self.emit(self.EVENT_DISCONNECTION) + @abstractmethod - def on_ctrl_pdu(self, pdu: bytes) -> None: + def _on_control_pdu(self, pdu: bytes) -> None: pass - def on_intr_pdu(self, pdu: bytes) -> None: - logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') - self.emit(self.EVENT_INTERRUPT_DATA, pdu) + def _on_interrupt_pdu(self, pdu: bytes) -> None: + message = DataMessage.from_bytes(pdu) + logger.debug('<<< [Interrupt] %s', message) + self.emit( + self.EVENT_INTERRUPT_DATA, + message.report_type, + message.data, + ) - def send_pdu_on_ctrl(self, msg: bytes) -> None: - assert self.l2cap_ctrl_channel - self.l2cap_ctrl_channel.send_pdu(msg) + def _send_control_pdu(self, message: Message) -> None: + if not self.control_channel: + raise core.InvalidStateError("Control channel is not connected") + logger.debug('>>> [Control] %s', message) + self.control_channel.send_pdu(message) - def send_pdu_on_intr(self, msg: bytes) -> None: - assert self.l2cap_intr_channel - self.l2cap_intr_channel.send_pdu(msg) + def _send_interrupt_pdu(self, message: Message) -> None: + if not self.interrupt_channel: + raise core.InvalidStateError("Interrupt channel is not connected") + logger.debug('>>> [Interrupt] %s', message) + self.interrupt_channel.send_pdu(message) - def send_data(self, data: bytes) -> None: + def send_interrupt_data(self, data: bytes) -> None: if self.role == HID.Role.HOST: - report_type = Message.ReportType.OUTPUT_REPORT + report_type = ReportType.OUTPUT_REPORT else: - report_type = Message.ReportType.INPUT_REPORT - msg = SendData(data, report_type) - hid_message = bytes(msg) - if self.l2cap_intr_channel is not None: - logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}') - self.send_pdu_on_intr(hid_message) + report_type = ReportType.INPUT_REPORT + if self.interrupt_channel is not None: + self._send_interrupt_pdu(DataMessage(data, report_type)) def virtual_cable_unplug(self) -> None: - msg = VirtualCableUnplug() - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) + self._send_control_pdu( + ControlMessage(ControlMessage.Command.VIRTUAL_CABLE_UNPLUG) + ) # ----------------------------------------------------------------------------- class Device(HID): - class GetSetReturn(enum.IntEnum): - FAILURE = 0x00 - REPORT_ID_NOT_FOUND = 0x01 - ERR_UNSUPPORTED_REQUEST = 0x02 - ERR_UNKNOWN = 0x03 - ERR_INVALID_PARAMETER = 0x04 - SUCCESS = 0xFF - @dataclass - class GetSetStatus: - data: bytes = b'' - status: int = 0 + EVENT_PROTOCOL_CHANGED = "protocol_changed" - get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None - set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None - get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None - set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None + class Delegate: + def set_report(self, report_type: ReportType, data: bytes) -> None: + del report_type, data # unused. + raise HidProtocolError(HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST) - def __init__(self, device: device.Device) -> None: - super().__init__(device, HID.Role.DEVICE) + def get_report( + self, report_type: ReportType, report_id: Optional[int] + ) -> bytes: + del report_type, report_id # unused. + raise HidProtocolError(HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST) + + role = HID.Role.DEVICE + + def __init__( + self, + device: device.Device, + delegate: Optional[Delegate] = None, + protocol: Optional[ProtocolMode] = None, + ) -> None: + super().__init__(device) + self.delegate = delegate + self.protocol = protocol @override - def on_ctrl_pdu(self, pdu: bytes) -> None: - logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') - param = pdu[0] & 0x0F - message_type = pdu[0] >> 4 - - if message_type == Message.MessageType.GET_REPORT: - logger.debug('<<< HID GET REPORT') - self.handle_get_report(pdu) - elif message_type == Message.MessageType.SET_REPORT: - logger.debug('<<< HID SET REPORT') - self.handle_set_report(pdu) - elif message_type == Message.MessageType.GET_PROTOCOL: - logger.debug('<<< HID GET PROTOCOL') - self.handle_get_protocol(pdu) - elif message_type == Message.MessageType.SET_PROTOCOL: - logger.debug('<<< HID SET PROTOCOL') - self.handle_set_protocol(pdu) - elif message_type == Message.MessageType.DATA: - logger.debug('<<< HID CONTROL DATA') - self.emit(self.EVENT_CONTROL_DATA, pdu) - elif message_type == Message.MessageType.CONTROL: - if param == Message.ControlCommand.SUSPEND: - logger.debug('<<< HID SUSPEND') - self.emit(self.EVENT_SUSPEND) - elif param == Message.ControlCommand.EXIT_SUSPEND: - logger.debug('<<< HID EXIT SUSPEND') - self.emit(self.EVENT_EXIT_SUSPEND) - elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: - logger.debug('<<< HID VIRTUAL CABLE UNPLUG') - self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG) - else: - logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') - else: - logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - - def send_handshake_message(self, result_code: int) -> None: - msg = SendHandshakeMessage(result_code) - hid_message = bytes(msg) - logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) - - def send_control_data(self, report_type: int, data: bytes): - msg = SendControlData(report_type=report_type, data=data) - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) - - def handle_get_report(self, pdu: bytes): - if self.get_report_cb is None: - logger.debug("GetReport callback not registered !!") - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - return - report_type = pdu[0] & 0x03 - buffer_flag = (pdu[0] & 0x08) >> 3 - report_id = pdu[1] - logger.debug(f"buffer_flag: {buffer_flag}") - if buffer_flag == 1: - buffer_size = (pdu[3] << 8) | pdu[2] - else: - buffer_size = 0 - - ret = self.get_report_cb(report_id, report_type, buffer_size) - if ret.status == self.GetSetReturn.FAILURE: - self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) - elif ret.status == self.GetSetReturn.SUCCESS: - data = bytearray() - data.append(report_id) - data.extend(ret.data) - if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr] - self.send_control_data(report_type=report_type, data=data) + def _on_control_pdu(self, pdu: bytes) -> None: + message = Message.from_bytes(pdu) + logger.debug('<<< [Control] %s', message) + + try: + if isinstance(message, GetReportMessage): + self._handle_get_report(message) + elif isinstance(message, SetReportMessage): + self._handle_set_report(message) + elif isinstance(message, GetProtocolMessage): + self._handle_get_protocol() + elif isinstance(message, SetProtocolMessage): + self._handle_set_protocol(message) + elif isinstance(message, DataMessage): + self.emit(self.EVENT_CONTROL_DATA, message) + elif isinstance(message, ControlMessage): + if message.command == ControlMessage.Command.SUSPEND: + self.emit(self.EVENT_SUSPEND) + elif message.command == ControlMessage.Command.EXIT_SUSPEND: + self.emit(self.EVENT_EXIT_SUSPEND) + elif message.command == ControlMessage.Command.VIRTUAL_CABLE_UNPLUG: + self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG) + else: + logger.error('Unsupported command %s', message.command.name) else: - self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: - self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) - elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: - self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - - def register_get_report_cb( - self, cb: Callable[[int, int, int], Device.GetSetStatus] - ) -> None: - self.get_report_cb = cb - logger.debug("GetReport callback registered successfully") + logger.error('Unsupported command type %s', message.message_type.name) + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) + except NotImplementedError: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) + except HidProtocolError as e: + self._send_handshake_message(e.result_code) - def handle_set_report(self, pdu: bytes): - if self.set_report_cb is None: - logger.debug("SetReport callback not registered !!") - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) - return - report_type = pdu[0] & 0x03 - report_id = pdu[1] - report_data = pdu[2:] - report_size = len(report_data) + 1 - ret = self.set_report_cb(report_id, report_type, report_size, report_data) - if ret.status == self.GetSetReturn.SUCCESS: - self.send_handshake_message(Message.Handshake.SUCCESSFUL) - elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: - self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) - elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND: - self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) - else: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + def _send_handshake_message(self, result_code: HandshakeMessage.ResultCode) -> None: + self._send_control_pdu(HandshakeMessage(result_code)) - def register_set_report_cb( - self, cb: Callable[[int, int, int, bytes], Device.GetSetStatus] - ) -> None: - self.set_report_cb = cb - logger.debug("SetReport callback registered successfully") + def _send_control_data(self, report_type: ReportType, data: bytes): + self._send_control_pdu(DataMessage(report_type=report_type, data=data)) - def handle_get_protocol(self, pdu: bytes): - if self.get_protocol_cb is None: - logger.debug("GetProtocol callback not registered !!") - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + def _handle_get_report(self, message: GetReportMessage) -> None: + if not self.delegate: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) return - ret = self.get_protocol_cb() - if ret.status == self.GetSetReturn.SUCCESS: - self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) - else: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + result = self.delegate.get_report(message.report_type, message.report_id) + data = ( + bytes(([message.report_id] if message.report_id is not None else [])) + + result + ) - def register_get_protocol_cb(self, cb: Callable[[], Device.GetSetStatus]) -> None: - self.get_protocol_cb = cb - logger.debug("GetProtocol callback registered successfully") + assert self.control_channel + if len(data) < self.control_channel.peer_mtu: + self._send_control_data(report_type=message.report_type, data=data) + else: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_INVALID_PARAMETER + ) - def handle_set_protocol(self, pdu: bytes): - if self.set_protocol_cb is None: - logger.debug("SetProtocol callback not registered !!") - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + def _handle_set_report(self, message: SetReportMessage): + if not self.delegate: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) return - ret = self.set_protocol_cb(pdu[0] & 0x01) - if ret.status == self.GetSetReturn.SUCCESS: - self.send_handshake_message(Message.Handshake.SUCCESSFUL) + self.delegate.set_report(message.report_type, message.data) + self._send_handshake_message(HandshakeMessage.ResultCode.SUCCESSFUL) + + def _handle_get_protocol(self): + if self.protocol is None: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) else: - self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) + self._send_control_data(ReportType.OTHER_REPORT, bytes([self.protocol])) - def register_set_protocol_cb( - self, cb: Callable[[int], Device.GetSetStatus] - ) -> None: - self.set_protocol_cb = cb - logger.debug("SetProtocol callback registered successfully") + def _handle_set_protocol(self, message: SetProtocolMessage): + if self.protocol is None: + self._send_handshake_message( + HandshakeMessage.ResultCode.ERR_UNSUPPORTED_REQUEST + ) + else: + self.protocol = message.protocol_mode + self._send_handshake_message(HandshakeMessage.ResultCode.SUCCESSFUL) + self.emit(self.EVENT_PROTOCOL_CHANGED) # ----------------------------------------------------------------------------- class Host(HID): + role = HID.Role.HOST + + _pending_command_future: Optional[asyncio.Future[Optional[DataMessage]]] = None + def __init__(self, device: device.Device) -> None: - super().__init__(device, HID.Role.HOST) + super().__init__(device) + self._report_queue = asyncio.Queue[bytes] + + async def _send_control_message(self, message: Message) -> Optional[DataMessage]: + self._pending_command_future = asyncio.get_running_loop().create_future() + self._send_control_pdu(message) + return await self._pending_command_future + + async def get_report( + self, + report_type: ReportType, + report_id: Optional[int] = None, + buffer_size: Optional[int] = None, + ) -> bytes: + result = await self._send_control_message( + GetReportMessage( + report_type=report_type, report_id=report_id, buffer_size=buffer_size + ) + ) + if result: + return result.data + else: + raise core.UnreachableError() + + async def set_report(self, report_type: ReportType, data: bytes) -> None: + await self._send_control_message( + SetReportMessage(report_type=report_type, data=data) + ) - def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: - msg = GetReportMessage( - report_type=report_type, report_id=report_id, buffer_size=buffer_size + async def get_protocol(self) -> ProtocolMode: + result = await self._send_control_message(GetProtocolMessage()) + if result: + return ProtocolMode(result.data[0]) + else: + raise core.UnreachableError() + + async def set_protocol(self, protocol_mode: ProtocolMode) -> None: + await self._send_control_message( + SetProtocolMessage(protocol_mode=protocol_mode) ) - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) - - def set_report(self, report_type: int, data: bytes) -> None: - msg = SetReportMessage(report_type=report_type, data=data) - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) - - def get_protocol(self) -> None: - msg = GetProtocolMessage() - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) - - def set_protocol(self, protocol_mode: int) -> None: - msg = SetProtocolMessage(protocol_mode=protocol_mode) - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) def suspend(self) -> None: - msg = Suspend() - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) + self._send_control_pdu(ControlMessage(ControlMessage.Command.SUSPEND)) def exit_suspend(self) -> None: - msg = ExitSuspend() - hid_message = bytes(msg) - logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}') - self.send_pdu_on_ctrl(hid_message) + self._send_control_pdu(ControlMessage(ControlMessage.Command.EXIT_SUSPEND)) @override - def on_ctrl_pdu(self, pdu: bytes) -> None: - logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') - param = pdu[0] & 0x0F - message_type = pdu[0] >> 4 - if message_type == Message.MessageType.HANDSHAKE: - logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}') - self.emit(self.EVENT_HANDSHAKE, Message.Handshake(param)) - elif message_type == Message.MessageType.DATA: - logger.debug('<<< HID CONTROL DATA') - self.emit(self.EVENT_CONTROL_DATA, pdu) - elif message_type == Message.MessageType.CONTROL: - if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: - logger.debug('<<< HID VIRTUAL CABLE UNPLUG') + def _on_control_pdu(self, pdu: bytes) -> None: + message = Message.from_bytes(pdu) + logger.debug('<<< [Control] %s', message) + if isinstance(message, DataMessage): + if self._pending_command_future and not self._pending_command_future.done(): + self._pending_command_future.set_result(message) + self._pending_command_future = None + else: + logger.error('Unexpected message %s', message) + elif isinstance(message, HandshakeMessage): + if self._pending_command_future and not self._pending_command_future.done(): + if message.result_code == HandshakeMessage.ResultCode.SUCCESSFUL: + self._pending_command_future.set_result(None) + else: + self._pending_command_future.set_exception( + HidProtocolError(message.result_code) + ) + self._pending_command_future = None + else: + logger.error('Unexpected message %s', message) + elif isinstance(message, ControlMessage): + if message.command == ControlMessage.Command.VIRTUAL_CABLE_UNPLUG: self.emit(self.EVENT_VIRTUAL_CABLE_UNPLUG) else: - logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') + logger.debug('Unsupported command %s', message.command.name) else: - logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') + logger.debug('Unsupported message %s', message.message_type.name) diff --git a/examples/keyboard.py b/examples/keyboard.py index e93de2c3..b033b933 100644 --- a/examples/keyboard.py +++ b/examples/keyboard.py @@ -16,14 +16,14 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import functools import json -import struct import sys import websockets.asyncio.server import bumble.logging -from bumble import data_types +from bumble import data_types, gatt_client from bumble.colors import color from bumble.core import AdvertisingData from bumble.device import Connection, Device, Peer @@ -148,17 +148,17 @@ async def on_disconnection(self, reason): # ----------------------------------------------------------------------------- -def on_hid_control_point_write(_connection, value): - print(f'Control Point Write: {value}') +def on_hid_control_point_write(_connection: Connection, value: bytes): + print(f'Control Point Write: {value.hex()}') # ----------------------------------------------------------------------------- -def on_report(characteristic, value): +def on_report(characteristic: gatt_client.CharacteristicProxy, value: bytes): print(color('Report:', 'cyan'), value.hex(), 'from', characteristic) # ----------------------------------------------------------------------------- -async def keyboard_host(device, peer_address): +async def keyboard_host(device: Device, peer_address: str): await device.power_on() connection = await device.connect(peer_address) await connection.pair() @@ -221,10 +221,7 @@ async def keyboard_host(device, peer_address): else: report_reference = bytes([0, 0]) await peer.subscribe( - characteristic, - lambda value, param=f'[{i}] {report_reference.hex()}': on_report( - param, value - ), + characteristic, functools.partial(on_report, characteristic) ) protocol_mode = await peer.read_value(protocol_mode_characteristic) @@ -238,7 +235,7 @@ async def keyboard_host(device, peer_address): # ----------------------------------------------------------------------------- -async def keyboard_device(device, command): +async def keyboard_device(device: Device, command: str): # Create an 'input report' characteristic to send keyboard reports to the host input_report_characteristic = Characteristic( GATT_REPORT_CHARACTERISTIC, diff --git a/examples/run_hid_device.py b/examples/run_hid_device.py deleted file mode 100644 index 887c493c..00000000 --- a/examples/run_hid_device.py +++ /dev/null @@ -1,745 +0,0 @@ -# Copyright 2021-2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- -import asyncio -import json -import struct -import sys - -import websockets.asyncio.server - -import bumble.logging -from bumble.core import ( - BT_HIDP_PROTOCOL_ID, - BT_HUMAN_INTERFACE_DEVICE_SERVICE, - BT_L2CAP_PROTOCOL_ID, - PhysicalTransport, -) -from bumble.device import Device -from bumble.hid import HID_CONTROL_PSM, HID_INTERRUPT_PSM -from bumble.hid import Device as HID_Device -from bumble.hid import Message -from bumble.sdp import ( - SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, - SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_PUBLIC_BROWSE_ROOT, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - DataElement, - ServiceAttribute, -) -from bumble.transport import open_transport - -# ----------------------------------------------------------------------------- -# SDP attributes for Bluetooth HID devices -SDP_HID_SERVICE_NAME_ATTRIBUTE_ID = 0x0100 -SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID = 0x0101 -SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID = 0x0102 -SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID = 0x0200 # [DEPRECATED] -SDP_HID_PARSER_VERSION_ATTRIBUTE_ID = 0x0201 -SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID = 0x0202 -SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID = 0x0203 -SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID = 0x0204 -SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID = 0x0205 -SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x0206 -SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID = 0x0207 -SDP_HID_SDP_DISABLE_ATTRIBUTE_ID = 0x0208 # [DEPRECATED] -SDP_HID_BATTERY_POWER_ATTRIBUTE_ID = 0x0209 -SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID = 0x020A -SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID = 0x020B # DEPRECATED] -SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID = 0x020C -SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID = 0x020D -SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E -SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F -SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210 - -# Refer to HID profile specification v1.1.1, "5.3 Service Discovery Protocol (SDP)" for details -# HID SDP attribute values -LANGUAGE = 0x656E # 0x656E uint16 “en” (English) -ENCODING = 0x6A # 0x006A uint16 UTF-8 encoding -PRIMARY_LANGUAGE_BASE_ID = 0x100 # 0x0100 uint16 PrimaryLanguageBaseID -VERSION_NUMBER = 0x0101 # 0x0101 uint16 version number (v1.1) -SERVICE_NAME = b'Bumble HID' -SERVICE_DESCRIPTION = b'Bumble' -PROVIDER_NAME = b'Bumble' -HID_PARSER_VERSION = 0x0111 # uint16 0x0111 (v1.1.1) -HID_DEVICE_SUBCLASS = 0xC0 # Combo keyboard/pointing device -HID_COUNTRY_CODE = 0x21 # 0x21 Uint8, USA -HID_VIRTUAL_CABLE = True # Virtual cable enabled -HID_RECONNECT_INITIATE = True # Reconnect initiate enabled -REPORT_DESCRIPTOR_TYPE = 0x22 # 0x22 Type = Report Descriptor -HID_LANGID_BASE_LANGUAGE = 0x0409 # 0x0409 Language = English (United States) -HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET = 0x100 # 0x0100 Default -HID_BATTERY_POWER = True # Battery power enabled -HID_REMOTE_WAKE = True # Remote wake enabled -HID_SUPERVISION_TIMEOUT = 0xC80 # uint16 0xC80 (2s) -HID_NORMALLY_CONNECTABLE = True # Normally connectable enabled -HID_BOOT_DEVICE = True # Boot device support enabled -HID_SSR_HOST_MAX_LATENCY = 0x640 # uint16 0x640 (1s) -HID_SSR_HOST_MIN_TIMEOUT = 0xC80 # uint16 0xC80 (2s) -HID_REPORT_MAP = bytes( # Text String, 50 Octet Report Descriptor - # pylint: disable=line-too-long - [ - 0x05, - 0x01, # Usage Page (Generic Desktop Ctrls) - 0x09, - 0x06, # Usage (Keyboard) - 0xA1, - 0x01, # Collection (Application) - 0x85, - 0x01, # . Report ID (1) - 0x05, - 0x07, # . Usage Page (Kbrd/Keypad) - 0x19, - 0xE0, # . Usage Minimum (0xE0) - 0x29, - 0xE7, # . Usage Maximum (0xE7) - 0x15, - 0x00, # . Logical Minimum (0) - 0x25, - 0x01, # . Logical Maximum (1) - 0x75, - 0x01, # . Report Size (1) - 0x95, - 0x08, # . Report Count (8) - 0x81, - 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position) - 0x95, - 0x01, # . Report Count (1) - 0x75, - 0x08, # . Report Size (8) - 0x81, - 0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position) - 0x95, - 0x05, # . Report Count (5) - 0x75, - 0x01, # . Report Size (1) - 0x05, - 0x08, # . Usage Page (LEDs) - 0x19, - 0x01, # . Usage Minimum (Num Lock) - 0x29, - 0x05, # . Usage Maximum (Kana) - 0x91, - 0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile) - 0x95, - 0x01, # . Report Count (1) - 0x75, - 0x03, # . Report Size (3) - 0x91, - 0x03, # . Output (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile) - 0x95, - 0x06, # . Report Count (6) - 0x75, - 0x08, # . Report Size (8) - 0x15, - 0x00, # . Logical Minimum (0) - 0x25, - 0x65, # . Logical Maximum (101) - 0x05, - 0x07, # . Usage Page (Kbrd/Keypad) - 0x19, - 0x00, # . Usage Minimum (0x00) - 0x29, - 0x65, # . Usage Maximum (0x65) - 0x81, - 0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position) - 0xC0, # End Collection - 0x05, - 0x01, # Usage Page (Generic Desktop Ctrls) - 0x09, - 0x02, # Usage (Mouse) - 0xA1, - 0x01, # Collection (Application) - 0x85, - 0x02, # . Report ID (2) - 0x09, - 0x01, # . Usage (Pointer) - 0xA1, - 0x00, # . Collection (Physical) - 0x05, - 0x09, # . Usage Page (Button) - 0x19, - 0x01, # . Usage Minimum (0x01) - 0x29, - 0x03, # . Usage Maximum (0x03) - 0x15, - 0x00, # . Logical Minimum (0) - 0x25, - 0x01, # . Logical Maximum (1) - 0x95, - 0x03, # . Report Count (3) - 0x75, - 0x01, # . Report Size (1) - 0x81, - 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position) - 0x95, - 0x01, # . Report Count (1) - 0x75, - 0x05, # . Report Size (5) - 0x81, - 0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position) - 0x05, - 0x01, # . Usage Page (Generic Desktop Ctrls) - 0x09, - 0x30, # . Usage (X) - 0x09, - 0x31, # . Usage (Y) - 0x15, - 0x81, # . Logical Minimum (-127) - 0x25, - 0x7F, # . Logical Maximum (127) - 0x75, - 0x08, # . Report Size (8) - 0x95, - 0x02, # . Report Count (2) - 0x81, - 0x06, # . Input (Data,Var,Rel,No Wrap,Linear,Preferred State,No Null Position) - 0xC0, # . End Collection - 0xC0, # End Collection - ] -) - - -# Default protocol mode set to report protocol -protocol_mode = Message.ProtocolMode.REPORT_PROTOCOL - - -# ----------------------------------------------------------------------------- -def sdp_records(): - service_record_handle = 0x00010002 - return { - service_record_handle: [ - ServiceAttribute( - SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(service_record_handle), - ), - ServiceAttribute( - SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, - DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), - ), - ServiceAttribute( - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE)] - ), - ), - ServiceAttribute( - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID), - DataElement.unsigned_integer_16(HID_CONTROL_PSM), - ] - ), - DataElement.sequence( - [ - DataElement.uuid(BT_HIDP_PROTOCOL_ID), - ] - ), - ] - ), - ), - ServiceAttribute( - SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.unsigned_integer_16(LANGUAGE), - DataElement.unsigned_integer_16(ENCODING), - DataElement.unsigned_integer_16(PRIMARY_LANGUAGE_BASE_ID), - ] - ), - ), - ServiceAttribute( - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE), - DataElement.unsigned_integer_16(VERSION_NUMBER), - ] - ), - ] - ), - ), - ServiceAttribute( - SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.uuid(BT_L2CAP_PROTOCOL_ID), - DataElement.unsigned_integer_16( - HID_INTERRUPT_PSM - ), - ] - ), - DataElement.sequence( - [ - DataElement.uuid(BT_HIDP_PROTOCOL_ID), - ] - ), - ] - ), - ] - ), - ), - ServiceAttribute( - SDP_HID_SERVICE_NAME_ATTRIBUTE_ID, - DataElement(DataElement.TEXT_STRING, SERVICE_NAME), - ), - ServiceAttribute( - SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID, - DataElement(DataElement.TEXT_STRING, SERVICE_DESCRIPTION), - ), - ServiceAttribute( - SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID, - DataElement(DataElement.TEXT_STRING, PROVIDER_NAME), - ), - ServiceAttribute( - SDP_HID_PARSER_VERSION_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(HID_PARSER_VERSION), - ), - ServiceAttribute( - SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(HID_DEVICE_SUBCLASS), - ), - ServiceAttribute( - SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID, - DataElement.unsigned_integer_32(HID_COUNTRY_CODE), - ), - ServiceAttribute( - SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID, - DataElement.boolean(HID_VIRTUAL_CABLE), - ), - ServiceAttribute( - SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID, - DataElement.boolean(HID_RECONNECT_INITIATE), - ), - ServiceAttribute( - SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.unsigned_integer_16(REPORT_DESCRIPTOR_TYPE), - DataElement(DataElement.TEXT_STRING, HID_REPORT_MAP), - ] - ), - ] - ), - ), - ServiceAttribute( - SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID, - DataElement.sequence( - [ - DataElement.sequence( - [ - DataElement.unsigned_integer_16( - HID_LANGID_BASE_LANGUAGE - ), - DataElement.unsigned_integer_16( - HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET - ), - ] - ), - ] - ), - ), - ServiceAttribute( - SDP_HID_BATTERY_POWER_ATTRIBUTE_ID, - DataElement.boolean(HID_BATTERY_POWER), - ), - ServiceAttribute( - SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID, - DataElement.boolean(HID_REMOTE_WAKE), - ), - ServiceAttribute( - SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID, - DataElement.unsigned_integer_16(HID_SUPERVISION_TIMEOUT), - ), - ServiceAttribute( - SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID, - DataElement.boolean(HID_NORMALLY_CONNECTABLE), - ), - ServiceAttribute( - SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID, - DataElement.boolean(HID_BOOT_DEVICE), - ), - ServiceAttribute( - SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID, - DataElement.unsigned_integer_16(HID_SSR_HOST_MAX_LATENCY), - ), - ServiceAttribute( - SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID, - DataElement.unsigned_integer_16(HID_SSR_HOST_MIN_TIMEOUT), - ), - ] - } - - -# ----------------------------------------------------------------------------- -async def get_stream_reader(pipe) -> asyncio.StreamReader: - loop = asyncio.get_event_loop() - reader = asyncio.StreamReader(loop=loop) - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, pipe) - return reader - - -class DeviceData: - def __init__(self) -> None: - self.keyboardData = bytearray( - [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - self.mouseData = bytearray([0x02, 0x00, 0x00, 0x00]) - - -# Device's live data - Mouse and Keyboard will be stored in this -deviceData = DeviceData() - - -# ----------------------------------------------------------------------------- -async def keyboard_device(hid_device: HID_Device): - - # Start a Websocket server to receive events from a web page - async def serve(websocket: websockets.asyncio.server.ServerConnection): - global deviceData - while True: - try: - message = await websocket.recv() - print('Received: ', str(message)) - parsed = json.loads(message) - message_type = parsed['type'] - if message_type == 'keydown': - # Only deal with keys a to z for now - key = parsed['key'] - if len(key) == 1: - code = ord(key) - if ord('a') <= code <= ord('z'): - hid_code = 0x04 + code - ord('a') - deviceData.keyboardData = bytearray( - [ - 0x01, - 0x00, - 0x00, - hid_code, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - hid_device.send_data(deviceData.keyboardData) - elif message_type == 'keyup': - deviceData.keyboardData = bytearray( - [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - hid_device.send_data(deviceData.keyboardData) - elif message_type == "mousemove": - # logical min and max values - log_min = -127 - log_max = 127 - x = parsed['x'] - y = parsed['y'] - # limiting x and y values within logical max and min range - x = max(log_min, min(log_max, x)) - y = max(log_min, min(log_max, y)) - deviceData.mouseData = bytearray([0x02, 0x00]) + struct.pack( - ">bb", x, y - ) - hid_device.send_data(deviceData.mouseData) - except websockets.exceptions.ConnectionClosedOK: - pass - - # pylint: disable-next=no-member - await websockets.asyncio.server.serve(serve, 'localhost', 8989) - await asyncio.get_event_loop().create_future() - - -# ----------------------------------------------------------------------------- -async def main() -> None: - if len(sys.argv) < 3: - print( - 'Usage: python run_hid_device.py ' - ' where is one of:\n' - ' test-mode (run with menu enabled for testing)\n' - ' web (run a keyboard with keypress input from a web page, ' - 'see keyboard.html' - ) - print('example: python run_hid_device.py hid_keyboard.json usb:0 web') - print('example: python run_hid_device.py hid_keyboard.json usb:0 test-mode') - - return - - async def handle_virtual_cable_unplug(): - hid_host_bd_addr = str(hid_device.remote_device_bd_address) - await hid_device.disconnect_interrupt_channel() - await hid_device.disconnect_control_channel() - await device.keystore.delete(hid_host_bd_addr) # type: ignore - connection = hid_device.connection - if connection is not None: - await connection.disconnect() - - def on_hid_data_cb(pdu: bytes): - print(f'Received Data, PDU: {pdu.hex()}') - - def on_get_report_cb( - report_id: int, report_type: int, buffer_size: int - ) -> HID_Device.GetSetStatus: - retValue = hid_device.GetSetStatus() - print( - "GET_REPORT report_id: " - + str(report_id) - + "report_type: " - + str(report_type) - + "buffer_size:" - + str(buffer_size) - ) - if report_type == Message.ReportType.INPUT_REPORT: - if report_id == 1: - retValue.data = deviceData.keyboardData[1:] - retValue.status = hid_device.GetSetReturn.SUCCESS - elif report_id == 2: - retValue.data = deviceData.mouseData[1:] - retValue.status = hid_device.GetSetReturn.SUCCESS - else: - retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND - - if buffer_size: - data_len = buffer_size - 1 - retValue.data = retValue.data[:data_len] - elif report_type == Message.ReportType.OUTPUT_REPORT: - # This sample app has nothing to do with the report received, to enable PTS - # testing, we will return single byte random data. - retValue.data = bytearray([0x11]) - retValue.status = hid_device.GetSetReturn.SUCCESS - elif report_type == Message.ReportType.FEATURE_REPORT: - retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER - elif report_type == Message.ReportType.OTHER_REPORT: - if report_id == 3: - retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND - else: - retValue.status = hid_device.GetSetReturn.FAILURE - - return retValue - - def on_set_report_cb( - report_id: int, report_type: int, report_size: int, data: bytes - ) -> HID_Device.GetSetStatus: - print( - "SET_REPORT report_id: " - + str(report_id) - + "report_type: " - + str(report_type) - + "report_size " - + str(report_size) - + "data:" - + str(data) - ) - if report_type == Message.ReportType.FEATURE_REPORT: - status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER - elif report_type == Message.ReportType.INPUT_REPORT: - if report_id == 1 and report_size != len(deviceData.keyboardData): - status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER - elif report_id == 2 and report_size != len(deviceData.mouseData): - status = HID_Device.GetSetReturn.ERR_INVALID_PARAMETER - elif report_id == 3: - status = HID_Device.GetSetReturn.REPORT_ID_NOT_FOUND - else: - status = HID_Device.GetSetReturn.SUCCESS - else: - status = HID_Device.GetSetReturn.SUCCESS - - return HID_Device.GetSetStatus(status=status) - - def on_get_protocol_cb() -> HID_Device.GetSetStatus: - return HID_Device.GetSetStatus( - data=bytes([protocol_mode]), - status=hid_device.GetSetReturn.SUCCESS, - ) - - def on_set_protocol_cb(protocol: int) -> HID_Device.GetSetStatus: - # We do not support SET_PROTOCOL. - print(f"SET_PROTOCOL report_id: {protocol}") - return HID_Device.GetSetStatus( - status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST - ) - - def on_virtual_cable_unplug_cb(): - print('Received Virtual Cable Unplug') - asyncio.create_task(handle_virtual_cable_unplug()) - - print('<<< connecting to HCI...') - async with await open_transport(sys.argv[2]) as hci_transport: - print('<<< connected') - - # Create a device - device = Device.from_config_file_with_hci( - sys.argv[1], hci_transport.source, hci_transport.sink - ) - device.classic_enabled = True - - # Create and register HID device - hid_device = HID_Device(device) - - # Register for call backs - hid_device.on('interrupt_data', on_hid_data_cb) - - hid_device.register_get_report_cb(on_get_report_cb) - hid_device.register_set_report_cb(on_set_report_cb) - hid_device.register_get_protocol_cb(on_get_protocol_cb) - hid_device.register_set_protocol_cb(on_set_protocol_cb) - - # Register for virtual cable unplug call back - hid_device.on('virtual_cable_unplug', on_virtual_cable_unplug_cb) - - # Setup the SDP to advertise HID Device service - device.sdp_service_records = sdp_records() - - # Start the controller - await device.power_on() - - # Start being discoverable and connectable - await device.set_discoverable(True) - await device.set_connectable(True) - - async def menu(): - reader = await get_stream_reader(sys.stdin) - while True: - print( - "\n************************ HID Device Menu *****************************\n" - ) - print(" 1. Connect Control Channel") - print(" 2. Connect Interrupt Channel") - print(" 3. Disconnect Control Channel") - print(" 4. Disconnect Interrupt Channel") - print(" 5. Send Report on Interrupt Channel") - print(" 6. Virtual Cable Unplug") - print(" 7. Disconnect device") - print(" 8. Delete Bonding") - print(" 9. Re-connect to device") - print("10. Exit ") - print("\nEnter your choice : \n") - - choice = await reader.readline() - choice = choice.decode('utf-8').strip() - - if choice == '1': - await hid_device.connect_control_channel() - - elif choice == '2': - await hid_device.connect_interrupt_channel() - - elif choice == '3': - await hid_device.disconnect_control_channel() - - elif choice == '4': - await hid_device.disconnect_interrupt_channel() - - elif choice == '5': - print(" 1. Report ID 0x01") - print(" 2. Report ID 0x02") - print(" 3. Invalid Report ID") - - choice1 = await reader.readline() - choice1 = choice1.decode('utf-8').strip() - - if choice1 == '1': - data = bytearray( - [0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - hid_device.send_data(data) - data = bytearray( - [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - hid_device.send_data(data) - - elif choice1 == '2': - data = bytearray([0x02, 0x00, 0x00, 0xF6]) - hid_device.send_data(data) - data = bytearray([0x02, 0x00, 0x00, 0x00]) - hid_device.send_data(data) - - elif choice1 == '3': - data = bytearray([0x00, 0x00, 0x00, 0x00]) - hid_device.send_data(data) - data = bytearray([0x00, 0x00, 0x00, 0x00]) - hid_device.send_data(data) - - else: - print('Incorrect option selected') - - elif choice == '6': - hid_device.virtual_cable_unplug() - try: - hid_host_bd_addr = str(hid_device.remote_device_bd_address) - await device.keystore.delete(hid_host_bd_addr) - except KeyError: - print('Device not found or Device already unpaired.') - - elif choice == '7': - connection = hid_device.connection - if connection is not None: - await connection.disconnect() - else: - print("Already disconnected from device") - - elif choice == '8': - try: - hid_host_bd_addr = str(hid_device.remote_device_bd_address) - await device.keystore.delete(hid_host_bd_addr) - except KeyError: - print('Device NOT found or Device already unpaired.') - - elif choice == '9': - hid_host_bd_addr = str(hid_device.remote_device_bd_address) - connection = await device.connect( - hid_host_bd_addr, transport=PhysicalTransport.BR_EDR - ) - await connection.authenticate() - await connection.encrypt() - - elif choice == '10': - sys.exit("Exit successful") - - else: - print("Invalid option selected.") - - if (len(sys.argv) > 3) and (sys.argv[3] == 'test-mode'): - # Test mode for PTS/Unit testing - await menu() - else: - # default option is using keyboard.html (web) - print("Executing in Web mode") - await keyboard_device(hid_device) - - await hci_transport.source.wait_for_termination() - - -# ----------------------------------------------------------------------------- -bumble.logging.setup_basic_logging('DEBUG') -asyncio.run(main()) diff --git a/examples/run_hid_host.py b/examples/run_hid_host.py deleted file mode 100644 index 2691c924..00000000 --- a/examples/run_hid_host.py +++ /dev/null @@ -1,565 +0,0 @@ -# Copyright 2021-2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- -import asyncio -import sys - -from hid_report_parser import ReportParser - -import bumble.logging -from bumble.colors import color -from bumble.core import BT_HUMAN_INTERFACE_DEVICE_SERVICE, PhysicalTransport -from bumble.device import Device -from bumble.hci import Address -from bumble.hid import Host, Message -from bumble.sdp import ( - SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_ALL_ATTRIBUTES_RANGE, - SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, - SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID, - SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, - SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, - SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, -) -from bumble.sdp import Client as SDP_Client -from bumble.transport import open_transport - -# ----------------------------------------------------------------------------- -# SDP attributes for Bluetooth HID devices -SDP_HID_SERVICE_NAME_ATTRIBUTE_ID = 0x0100 -SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID = 0x0101 -SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID = 0x0102 -SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID = 0x0200 # [DEPRECATED] -SDP_HID_PARSER_VERSION_ATTRIBUTE_ID = 0x0201 -SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID = 0x0202 -SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID = 0x0203 -SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID = 0x0204 -SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID = 0x0205 -SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x0206 -SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID = 0x0207 -SDP_HID_SDP_DISABLE_ATTRIBUTE_ID = 0x0208 # [DEPRECATED] -SDP_HID_BATTERY_POWER_ATTRIBUTE_ID = 0x0209 -SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID = 0x020A -SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID = 0x020B # DEPRECATED] -SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID = 0x020C -SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID = 0x020D -SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E -SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F -SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210 - - -# ----------------------------------------------------------------------------- - - -async def get_hid_device_sdp_record(connection): - - # Connect to the SDP Server - sdp_client = SDP_Client(connection) - await sdp_client.connect() - if sdp_client: - print(color('Connected to SDP Server', 'blue')) - else: - print(color('Failed to connect to SDP Server', 'red')) - - # List BT HID Device service in the root browse group - service_record_handles = await sdp_client.search_services( - [BT_HUMAN_INTERFACE_DEVICE_SERVICE] - ) - - if len(service_record_handles) < 1: - await sdp_client.disconnect() - raise Exception( - color(f'BT HID Device service not found on peer device!!!!', 'red') - ) - - # For BT_HUMAN_INTERFACE_DEVICE_SERVICE service, get all its attributes - for service_record_handle in service_record_handles: - attributes = await sdp_client.get_attributes( - service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE] - ) - print(color(f'SERVICE {service_record_handle:04X} attributes:', 'yellow')) - print(color(f'SDP attributes for HID device', 'magenta')) - for attribute in attributes: - if attribute.id == SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: - print( - color(' Service Record Handle : ', 'cyan'), - hex(attribute.value.value), - ) - - elif attribute.id == SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: - print( - color(' Service Class : ', 'cyan'), attribute.value.value[0].value - ) - - elif attribute.id == SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID: - print( - color(' SDP Browse Group List : ', 'cyan'), - attribute.value.value[0].value, - ) - - elif attribute.id == SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: - print( - color(' BT_L2CAP_PROTOCOL_ID : ', 'cyan'), - attribute.value.value[0].value[0].value, - ) - print( - color(' PSM for Bluetooth HID Control channel : ', 'cyan'), - hex(attribute.value.value[0].value[1].value), - ) - print( - color(' BT_HIDP_PROTOCOL_ID : ', 'cyan'), - attribute.value.value[1].value[0].value, - ) - - elif attribute.id == SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID: - print( - color(' Lanugage : ', 'cyan'), hex(attribute.value.value[0].value) - ) - print( - color(' Encoding : ', 'cyan'), hex(attribute.value.value[1].value) - ) - print( - color(' PrimaryLanguageBaseID : ', 'cyan'), - hex(attribute.value.value[2].value), - ) - - elif attribute.id == SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID: - print( - color(' BT_HUMAN_INTERFACE_DEVICE_SERVICE ', 'cyan'), - attribute.value.value[0].value[0].value, - ) - print( - color(' HID Profileversion number : ', 'cyan'), - hex(attribute.value.value[0].value[1].value), - ) - - elif attribute.id == SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: - print( - color(' BT_L2CAP_PROTOCOL_ID : ', 'cyan'), - attribute.value.value[0].value[0].value[0].value, - ) - print( - color(' PSM for Bluetooth HID Interrupt channel : ', 'cyan'), - hex(attribute.value.value[0].value[0].value[1].value), - ) - print( - color(' BT_HIDP_PROTOCOL_ID : ', 'cyan'), - attribute.value.value[0].value[1].value[0].value, - ) - - elif attribute.id == SDP_HID_SERVICE_NAME_ATTRIBUTE_ID: - print(color(' Service Name: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID: - print(color(' Service Description: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID: - print(color(' Provider Name: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID: - print(color(' Release Number: ', 'cyan'), hex(attribute.value.value)) - - elif attribute.id == SDP_HID_PARSER_VERSION_ATTRIBUTE_ID: - print( - color(' HID Parser Version: ', 'cyan'), hex(attribute.value.value) - ) - - elif attribute.id == SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID: - print( - color(' HIDDeviceSubclass: ', 'cyan'), hex(attribute.value.value) - ) - - elif attribute.id == SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID: - print(color(' HIDCountryCode: ', 'cyan'), hex(attribute.value.value)) - - elif attribute.id == SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID: - print(color(' HIDVirtualCable: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID: - print(color(' HIDReconnectInitiate: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID: - print( - color(' HID Report Descriptor type: ', 'cyan'), - hex(attribute.value.value[0].value[0].value), - ) - print( - color(' HID Report DescriptorList: ', 'cyan'), - attribute.value.value[0].value[1].value, - ) - - elif attribute.id == SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID: - print( - color(' HID LANGID Base Language: ', 'cyan'), - hex(attribute.value.value[0].value[0].value), - ) - print( - color(' HID LANGID Base Bluetooth String Offset: ', 'cyan'), - hex(attribute.value.value[0].value[1].value), - ) - - elif attribute.id == SDP_HID_BATTERY_POWER_ATTRIBUTE_ID: - print(color(' HIDBatteryPower: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID: - print(color(' HIDRemoteWake: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID: - print( - color(' HIDProfileVersion : ', 'cyan'), hex(attribute.value.value) - ) - - elif attribute.id == SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID: - print( - color(' HIDSupervisionTimeout: ', 'cyan'), - hex(attribute.value.value), - ) - - elif attribute.id == SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID: - print( - color(' HIDNormallyConnectable: ', 'cyan'), attribute.value.value - ) - - elif attribute.id == SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID: - print(color(' HIDBootDevice: ', 'cyan'), attribute.value.value) - - elif attribute.id == SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID: - print( - color(' HIDSSRHostMaxLatency: ', 'cyan'), - hex(attribute.value.value), - ) - - elif attribute.id == SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID: - print( - color(' HIDSSRHostMinTimeout: ', 'cyan'), - hex(attribute.value.value), - ) - - else: - print( - color( - f' Warning: Attribute ID: {attribute.id} match not found.\n Attribute Info: {attribute}', - 'yellow', - ) - ) - - await sdp_client.disconnect() - - -# ----------------------------------------------------------------------------- -async def get_stream_reader(pipe) -> asyncio.StreamReader: - loop = asyncio.get_event_loop() - reader = asyncio.StreamReader(loop=loop) - protocol = asyncio.StreamReaderProtocol(reader) - await loop.connect_read_pipe(lambda: protocol, pipe) - return reader - - -# ----------------------------------------------------------------------------- -async def main() -> None: - if len(sys.argv) < 4: - print( - 'Usage: run_hid_host.py ' - ' [test-mode]' - ) - - print('example: run_hid_host.py classic1.json usb:0 E1:CA:72:48:C4:E8/P') - return - - def on_hid_control_data_cb(pdu: bytes): - print(f'Received Control Data, PDU: {pdu.hex()}') - - def on_hid_interrupt_data_cb(pdu: bytes): - report_type = pdu[0] & 0x0F - if len(pdu) == 1: - print(color(f'Warning: No report received', 'yellow')) - return - report_length = len(pdu[1:]) - report_id = pdu[1] - - if report_type != Message.ReportType.OTHER_REPORT: - print( - color( - f' Report type = {report_type}, Report length = {report_length}, Report id = {report_id}', - 'blue', - None, - 'bold', - ) - ) - - if (report_length <= 1) or (report_id == 0): - return - # Parse report over interrupt channel - if report_type == Message.ReportType.INPUT_REPORT: - ReportParser.parse_input_report(pdu[1:]) # type: ignore - - async def handle_virtual_cable_unplug(): - await hid_host.disconnect_interrupt_channel() - await hid_host.disconnect_control_channel() - await device.keystore.delete(target_address) # type: ignore - connection = hid_host.connection - if connection is not None: - await connection.disconnect() - - def on_hid_virtual_cable_unplug_cb(): - asyncio.create_task(handle_virtual_cable_unplug()) - - print('<<< connecting to HCI...') - async with await open_transport(sys.argv[2]) as hci_transport: - print('<<< CONNECTED') - - # Create a device - device = Device.from_config_file_with_hci( - sys.argv[1], hci_transport.source, hci_transport.sink - ) - device.classic_enabled = True - - # Create HID host and start it - print('@@@ Starting HID Host...') - hid_host = Host(device) - - # Register for HID data call back - hid_host.on('interrupt_data', on_hid_interrupt_data_cb) - hid_host.on('control_data', on_hid_control_data_cb) - - # Register for virtual cable unplug call back - hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb) - - await device.power_on() - - # Connect to a peer - target_address = sys.argv[3] - print(f'=== Connecting to {target_address}...') - connection = await device.connect( - target_address, transport=PhysicalTransport.BR_EDR - ) - print(f'=== Connected to {connection.peer_address}!') - - # Request authentication - print('*** Authenticating...') - await connection.authenticate() - print('*** Authenticated...') - - # Enable encryption - print('*** Enabling encryption...') - await connection.encrypt() - print('*** Encryption on') - - await get_hid_device_sdp_record(connection) - - async def menu(): - reader = await get_stream_reader(sys.stdin) - while True: - print( - "\n************************ HID Host Menu *****************************\n" - ) - print(" 1. Connect Control Channel") - print(" 2. Connect Interrupt Channel") - print(" 3. Disconnect Control Channel") - print(" 4. Disconnect Interrupt Channel") - print(" 5. Get Report") - print(" 6. Set Report") - print(" 7. Set Protocol Mode") - print(" 8. Get Protocol Mode") - print(" 9. Send Report on Interrupt Channel") - print("10. Suspend") - print("11. Exit Suspend") - print("12. Virtual Cable Unplug") - print("13. Disconnect device") - print("14. Delete Bonding") - print("15. Re-connect to device") - print("16. Exit") - print("\nEnter your choice : \n") - - choice = await reader.readline() - choice = choice.decode('utf-8').strip() - - if choice == '1': - await hid_host.connect_control_channel() - - elif choice == '2': - await hid_host.connect_interrupt_channel() - - elif choice == '3': - await hid_host.disconnect_control_channel() - - elif choice == '4': - await hid_host.disconnect_interrupt_channel() - - elif choice == '5': - print(" 1. Input Report with ID 0x01") - print(" 2. Input Report with ID 0x02") - print(" 3. Input Report with ID 0x0F - Invalid ReportId") - print(" 4. Output Report with ID 0x02") - print(" 5. Feature Report with ID 0x05 - Unsupported Request") - print(" 6. Input Report with ID 0x02, BufferSize 3") - print(" 7. Output Report with ID 0x03, BufferSize 2") - print(" 8. Feature Report with ID 0x05, BufferSize 3") - choice1 = await reader.readline() - choice1 = choice1.decode('utf-8').strip() - - if choice1 == '1': - hid_host.get_report(1, 1, 0) - - elif choice1 == '2': - hid_host.get_report(1, 2, 0) - - elif choice1 == '3': - hid_host.get_report(1, 5, 0) - - elif choice1 == '4': - hid_host.get_report(2, 2, 0) - - elif choice1 == '5': - hid_host.get_report(3, 15, 0) - - elif choice1 == '6': - hid_host.get_report(1, 2, 3) - - elif choice1 == '7': - hid_host.get_report(2, 3, 2) - - elif choice1 == '8': - hid_host.get_report(3, 5, 3) - else: - print('Incorrect option selected') - - elif choice == '6': - print(" 1. Report type 1 and Report id 0x01") - print(" 2. Report type 2 and Report id 0x03") - print(" 3. Report type 3 and Report id 0x05") - choice1 = await reader.readline() - choice1 = choice1.decode('utf-8').strip() - - if choice1 == '1': - # data includes first octet as report id - data = bytearray( - [0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01] - ) - hid_host.set_report(1, data) - - elif choice1 == '2': - data = bytearray([0x03, 0x01, 0x01]) - hid_host.set_report(2, data) - - elif choice1 == '3': - data = bytearray([0x05, 0x01, 0x01, 0x01]) - hid_host.set_report(3, data) - - else: - print('Incorrect option selected') - - elif choice == '7': - print(" 0. Boot") - print(" 1. Report") - choice1 = await reader.readline() - choice1 = choice1.decode('utf-8').strip() - - if choice1 == '0': - hid_host.set_protocol(Message.ProtocolMode.BOOT_PROTOCOL) - - elif choice1 == '1': - hid_host.set_protocol(Message.ProtocolMode.REPORT_PROTOCOL) - - else: - print('Incorrect option selected') - - elif choice == '8': - hid_host.get_protocol() - - elif choice == '9': - print(" 1. Report ID 0x01") - print(" 2. Report ID 0x03") - choice1 = await reader.readline() - choice1 = choice1.decode('utf-8').strip() - - if choice1 == '1': - data = bytearray( - [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] - ) - hid_host.send_data(data) - - elif choice1 == '2': - data = bytearray([0x03, 0x00, 0x0D, 0xFD, 0x00, 0x00]) - hid_host.send_data(data) - - else: - print('Incorrect option selected') - - elif choice == '10': - hid_host.suspend() - - elif choice == '11': - hid_host.exit_suspend() - - elif choice == '12': - hid_host.virtual_cable_unplug() - try: - await device.keystore.delete(target_address) - print("Unpair successful") - except KeyError: - print('Device not found or Device already unpaired.') - - elif choice == '13': - peer_address = Address.from_string_for_transport( - target_address, transport=PhysicalTransport.BR_EDR - ) - connection = device.find_connection_by_bd_addr( - peer_address, transport=PhysicalTransport.BR_EDR - ) - if connection is not None: - await connection.disconnect() - else: - print("Already disconnected from device") - - elif choice == '14': - try: - await device.keystore.delete(target_address) - print("Unpair successful") - except KeyError: - print('Device not found or Device already unpaired.') - - elif choice == '15': - connection = await device.connect( - target_address, transport=PhysicalTransport.BR_EDR - ) - await connection.authenticate() - await connection.encrypt() - - elif choice == '16': - sys.exit("Exit successful") - - else: - print("Invalid option selected.") - - if (len(sys.argv) > 4) and (sys.argv[4] == 'test-mode'): - # Enabling menu for testing - await menu() - else: - # HID Connection - # Control channel - await hid_host.connect_control_channel() - # Interrupt Channel - await hid_host.connect_interrupt_channel() - - await hci_transport.source.wait_for_termination() - - -# ----------------------------------------------------------------------------- -bumble.logging.setup_basic_logging('DEBUG') -asyncio.run(main()) diff --git a/tests/hid_test.py b/tests/hid_test.py new file mode 100644 index 00000000..23a973cf --- /dev/null +++ b/tests/hid_test.py @@ -0,0 +1,173 @@ +# Copyright 2021-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import asyncio +import logging + +import pytest + +from bumble import hid + +from . import test_utils + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +async def hid_protocols() -> tuple[hid.Host, hid.Device]: + devices = await test_utils.TwoDevices.create_with_connection() + host = hid.Host(devices[0]) + device = hid.Device(devices[1]) + assert devices.connections[0] + await host.connect(devices.connections[0]) + return host, device + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_connection(): + devices = await test_utils.TwoDevices.create_with_connection() + host = hid.Host(devices[0]) + device = hid.Device(devices[1]) + + connected = asyncio.Event() + device.on(device.EVENT_CONNECTION, lambda: connected.set()) + await host.connect(devices.connections[0]) + await connected.wait() + + disconnected = asyncio.Event() + device.on(device.EVENT_DISCONNECTION, lambda: disconnected.set()) + await host.disconnect() + await disconnected.wait() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_device_send_interrupt_data(): + host, device = await hid_protocols() + queue = asyncio.Queue[tuple[hid.ReportType, bytes]]() + + @host.on(host.EVENT_INTERRUPT_DATA) + def _(report_type: hid.ReportType, data: bytes): + queue.put_nowait((report_type, data)) + + device.send_interrupt_data(b'123') + assert (await queue.get()) == (hid.ReportType.INPUT_REPORT, b'123') + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_host_send_interrupt_data(): + host, device = await hid_protocols() + queue = asyncio.Queue[tuple[hid.ReportType, bytes]]() + + @device.on(device.EVENT_INTERRUPT_DATA) + def _(report_type: hid.ReportType, data: bytes): + queue.put_nowait((report_type, data)) + + host.send_interrupt_data(b'123') + assert (await queue.get()) == (hid.ReportType.OUTPUT_REPORT, b'123') + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_device_virtual_cable_unplug(): + host, device = await hid_protocols() + unplugged = asyncio.Event() + host.on(host.EVENT_VIRTUAL_CABLE_UNPLUG, lambda: unplugged.set()) + + device.virtual_cable_unplug() + await unplugged.wait() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_host_virtual_cable_unplug(): + host, device = await hid_protocols() + unplugged = asyncio.Event() + device.on(device.EVENT_VIRTUAL_CABLE_UNPLUG, lambda: unplugged.set()) + + host.virtual_cable_unplug() + await unplugged.wait() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_get_protocol(): + host, device = await hid_protocols() + + device.protocol = hid.ProtocolMode.BOOT_PROTOCOL + assert await host.get_protocol() == hid.ProtocolMode.BOOT_PROTOCOL + + await host.set_protocol(hid.ProtocolMode.REPORT_PROTOCOL) + assert await host.get_protocol() == hid.ProtocolMode.REPORT_PROTOCOL + + device.protocol = None + with pytest.raises(hid.HidProtocolError): + await host.get_protocol() + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_set_get_report(): + host, device = await hid_protocols() + + class Delegate(hid.Device.Delegate): + def __init__(self): + super().__init__() + self.reports = {} + + def get_report( + self, report_type: hid.ReportType, report_id: int | None + ) -> bytes: + return self.reports[report_type] + + def set_report(self, report_type: hid.ReportType, data: bytes) -> None: + self.reports[report_type] = data + + device.delegate = Delegate() + device.delegate.reports[hid.ReportType.INPUT_REPORT] = b'123' + + assert await host.get_report(hid.ReportType.INPUT_REPORT) == b'123' + + await host.set_report(hid.ReportType.OUTPUT_REPORT, b'456') + assert await host.get_report(hid.ReportType.OUTPUT_REPORT) == b'456' + + device.delegate = None + with pytest.raises(hid.HidProtocolError): + await host.get_report(hid.ReportType.INPUT_REPORT) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_suspend_resume(): + host, device = await hid_protocols() + + suspended = asyncio.Event() + device.on(device.EVENT_SUSPEND, lambda: suspended.set()) + host.suspend() + await suspended.wait() + + resumed = asyncio.Event() + device.on(device.EVENT_EXIT_SUSPEND, lambda: resumed.set()) + host.exit_suspend() + await resumed.wait()