diff --git a/pylabrobot/io/usb.py b/pylabrobot/io/usb.py index 6865e7001ab..f9d1c1afd7a 100644 --- a/pylabrobot/io/usb.py +++ b/pylabrobot/io/usb.py @@ -124,38 +124,49 @@ async def write(self, data: bytes, timeout: Optional[float] = None): ) logger.log(LOG_LEVEL_IO, "%s write: %s", self._unique_id, data) capturer.record( - USBCommand(device_id=self._unique_id, action="write", data=data.decode("unicode_escape")) + USBCommand( + device_id=self._unique_id, + action="write", + data=data.decode("unicode_escape", errors="backslashreplace"), + ) ) - def _read_packet(self) -> Optional[bytearray]: + def _read_packet(self, size: Optional[int] = None) -> Optional[bytearray]: """Read a packet from the machine. + Args: + size: The maximum number of bytes to read. If `None`, read up to wMaxPacketSize bytes. + Returns: - A string containing the decoded packet, or None if no packet was received. + A bytearray containing the data read, or None if no data was received. """ assert self.dev is not None and self.read_endpoint is not None, "Device not connected." + read_size = size if size is not None else self.read_endpoint.wMaxPacketSize + try: res = self.dev.read( self.read_endpoint, - self.read_endpoint.wMaxPacketSize, + read_size, timeout=int(self.packet_read_timeout * 1000), # timeout in ms ) if res is not None: - return bytearray(res) # convert res into text + return bytearray(res) return None except usb.core.USBError: # No data available (yet), this will give a timeout error. Don't reraise. return None - async def read(self, timeout: Optional[int] = None) -> bytes: + async def read(self, timeout: Optional[int] = None, size: Optional[int] = None) -> bytes: """Read a response from the device. Args: timeout: The timeout for reading from the device in seconds. If `None`, use the default timeout (specified by the `read_timeout` attribute). + size: The maximum number of bytes to read. If `None`, read all available data until no + more packets arrive. """ assert self.read_endpoint is not None, "Device not connected." @@ -173,20 +184,27 @@ def read_or_timeout(): resp = bytearray() last_packet: Optional[bytearray] = None while True: # read while we have data, and while the last packet is the max size. - last_packet = self._read_packet() + remaining = size - len(resp) if size is not None else None + last_packet = self._read_packet(size=remaining) if last_packet is not None: resp += last_packet if self.read_endpoint is None: raise RuntimeError("Read endpoint is None. Call setup() first.") if last_packet is None or len(last_packet) != self.read_endpoint.wMaxPacketSize: break + if size is not None and len(resp) >= size: + break if len(resp) == 0: continue logger.log(LOG_LEVEL_IO, "%s read: %s", self._unique_id, resp) capturer.record( - USBCommand(device_id=self._unique_id, action="read", data=resp.decode("unicode_escape")) + USBCommand( + device_id=self._unique_id, + action="read", + data=resp.decode("unicode_escape", errors="backslashreplace"), + ) ) return resp @@ -420,11 +438,12 @@ async def write(self, data: bytes, timeout: Optional[float] = None): and next_command.action == "write" ): raise ValidationError("next command is not write") - if not next_command.data == data.decode("unicode_escape"): - align_sequences(expected=next_command.data, actual=data.decode("unicode_escape")) + decoded = data.decode("unicode_escape", errors="backslashreplace") + if not next_command.data == decoded: + align_sequences(expected=next_command.data, actual=decoded) raise ValidationError("Data mismatch: difference was written to stdout.") - async def read(self, timeout: Optional[float] = None) -> bytes: + async def read(self, timeout: Optional[float] = None, size: Optional[int] = None) -> bytes: next_command = USBCommand(**self.cr.next_command()) if not ( next_command.module == "usb" @@ -432,7 +451,10 @@ async def read(self, timeout: Optional[float] = None) -> bytes: and next_command.action == "read" ): raise ValidationError("next command is not read") - return next_command.data.encode() + data = next_command.data.encode() + if size is not None: + data = data[:size] + return data def ctrl_transfer( self,