Skip to content

Commit

Permalink
connections.ev3: Improve firmware update.
Browse files Browse the repository at this point in the history
- Allow faster flashing of smaller firmwares by erasing only as much as needed.
- Work around USB3.0 issues with the EV3 bootloader.
  • Loading branch information
laurensvalk committed Oct 31, 2024
1 parent 5e2a361 commit 81f310d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 30 deletions.
22 changes: 13 additions & 9 deletions pybricksdev/cli/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,24 +382,28 @@ async def flash_ev3(firmware: bytes) -> None:
fw, hw = await bootloader.get_version()
print(f"hwid: {hw}")

ERASE_TICKS = 60

# Erasing doesn't have any feedback so we just use time for the progress
# bar. The operation runs on the EV3, so the time is the same for everyone.
async def tick(callback):
for _ in range(ERASE_TICKS):
await asyncio.sleep(1)
callback(1)
CHUNK = 8000
SPEED = 256000
for _ in range(len(firmware) // CHUNK):
await asyncio.sleep(CHUNK / SPEED)
callback(CHUNK)

print("Erasing memory...")
with logging_redirect_tqdm(), tqdm(total=ERASE_TICKS) as pbar:
await asyncio.gather(bootloader.erase_chip(), tick(pbar.update))
print("Erasing memory and preparing firmware download...")
with logging_redirect_tqdm(), tqdm(
total=len(firmware), unit="B", unit_scale=True
) as pbar:
await asyncio.gather(
bootloader.erase_and_begin_download(len(firmware)), tick(pbar.update)
)

print("Downloading firmware...")
with logging_redirect_tqdm(), tqdm(
total=len(firmware), unit="B", unit_scale=True
) as pbar:
await bootloader.download(0, firmware, pbar.update)
await bootloader.download(firmware, pbar.update)

print("Verifying...", end="", flush=True)
checksum = await bootloader.get_checksum(0, len(firmware))
Expand Down
60 changes: 39 additions & 21 deletions pybricksdev/connections/ev3.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,20 @@ def _send_command(self, command: Command, payload: Optional[bytes] = None) -> in

return message_number

def _receive_reply(self, command: Command, message_number: int) -> bytes:
def _receive_reply(
self, command: Command, message_number: int, force_length: int = 0
) -> bytes:
"""
Receive a reply from the EV3 bootloader.
Args:
command: The command that was sent.
message_number: The return value of :meth:`_send_command`.
force_length: Expected length, used only when it fails to unpack
normally. Some replies on USB 3.0 hosts contain
the original command written over the reply. This
means the header is bad, but the payload may be in
tact if you know what data to expect.
Returns:
The payload of the reply.
Expand All @@ -131,36 +138,41 @@ def _receive_reply(self, command: Command, message_number: int) -> bytes:
raise ReplyError(status)

if message_type != MessageType.SYSTEM_REPLY:
raise RuntimeError("unexpected message type: {message_type}")
if force_length:
return reply[7 : force_length + 2]
raise RuntimeError(f"unexpected message type: {message_type}")

if reply_command != command:
raise RuntimeError("command mismatch: {reply_command} != {command}")
raise RuntimeError(f"command mismatch: {reply_command} != {command}")

return reply[7 : length + 2]

def download_sync(
self,
address: int,
data: bytes,
progress: Optional[Callable[[int], None]] = None,
) -> None:
"""
Blocking version of :meth:`download`.
"""
param_data = struct.pack("<II", address, len(data))
num = self._send_command(Command.BEGIN_DOWNLOAD, param_data)
self._receive_reply(Command.BEGIN_DOWNLOAD, num)

completed = 0
for c in chunk(data, self._MAX_DATA_SIZE):
num = self._send_command(Command.DOWNLOAD_DATA, c)
self._receive_reply(Command.DOWNLOAD_DATA, num)
try:
completed += len(c)
self._receive_reply(Command.DOWNLOAD_DATA, num)
except RuntimeError as e:
# Allow exception only on the final chunk.
if completed != len(data):
raise e
print(e, ". Proceeding anyway.")

if progress:
progress(len(c))

async def download(
self,
address: int,
data: bytes,
progress: Optional[Callable[[int], None]] = None,
) -> None:
Expand All @@ -170,30 +182,31 @@ async def download(
This operation takes about 60 seconds for a full 16MB firmware file.
Args:
address: The starting address of where to write the data.
data: The data to write.
progress: Optional callback for indicating progress.
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.download_sync, address, data, progress
None, self.download_sync, data, progress
)

def erase_chip_sync(self) -> None:
def erase_and_begin_download_sync(self, size) -> None:
"""
Blocking version of :meth:`erase_chip`.
Blocking version of :meth:`erase_and_begin_download`.
"""
num = self._send_command(Command.CHIP_ERASE)
self._receive_reply(Command.CHIP_ERASE, num)
param_data = struct.pack("<II", 0, size)
num = self._send_command(Command.BEGIN_DOWNLOAD_WITH_ERASE, param_data)
self._receive_reply(Command.BEGIN_DOWNLOAD_WITH_ERASE, num)

async def erase_chip(self) -> None:
async def erase_and_begin_download(self, size) -> None:
"""
Erases the external flash memory chip.
Erases the external flash memory chip by the amount required to
flash the new firmware. Also prepares firmware download.
This operation takes about 60 seconds.
Args:
size: How much to erase.
"""
return await asyncio.get_running_loop().run_in_executor(
None,
self.erase_chip_sync,
None, self.erase_and_begin_download_sync, size
)

def start_app_sync(self) -> None:
Expand Down Expand Up @@ -241,7 +254,12 @@ def get_version_sync(self) -> Tuple[int, int]:
Blocking version of :meth:`get_version`.
"""
num = self._send_command(Command.GET_VERSION)
payload = self._receive_reply(Command.GET_VERSION, num)
# On certain USB 3.0 systems, the brick reply contains the command
# we just sent written over it. This means we don't get the correct
# header and length info. Since the command here is smaller than the
# reply, the paypload does not get overwritten, so we can still get
# the version info since we know the expected reply size.
payload = self._receive_reply(Command.GET_VERSION, num, force_length=13)
return struct.unpack("<II", payload)

async def get_version(self) -> Tuple[int, int]:
Expand Down

0 comments on commit 81f310d

Please sign in to comment.