# Source: commit b484fcf9ed (342 lines, 12 KiB, Python)
# - New communication protocol for downloaded chunks based on queues instead of events.
import enum
import logging
import os
import posixpath
import queue
import re
import struct
import threading
from typing import Optional, Union

import websocket

from platypush.utils import grouper


class Connection:
    """
    This class models the connection with an ESP8266/ESP32 device over its WebREPL websocket channel.

    The ``on_*`` methods are callbacks (presumably driven by the websocket event handler) that
    update ``state`` and signal ``threading.Event`` objects or the download queue. The blocking
    methods (``send``, ``wait_*``, ``file_upload``, ``file_download``) wait on those signals.
    """

    _file_transfer_buffer_size = 1024

    # Prompt/continuation/PASTE-mode markers that are not part of a command echo
    _prompt_re = re.compile(r'\s+(?:>>>|\.\.\.|===)\s+')

    class State(enum.IntEnum):
        DISCONNECTED = 1
        CONNECTED = 2
        PASSWORD_REQUIRED = 3
        READY = 4
        SENDING_REQUEST = 5
        WAITING_ECHO = 6
        WAITING_RESPONSE = 7
        SENDING_FILE_TRANSFER_RESPONSE = 8
        WAITING_FILE_TRANSFER_RESPONSE = 9
        UPLOADING_FILE = 10
        DOWNLOADING_FILE = 11

    class FileRequestType(enum.IntEnum):
        # Written as-is into the request type byte of the file transfer header
        UPLOAD = 1
        DOWNLOAD = 2

    def __init__(self, host: str, port: int, connect_timeout: Optional[float] = None,
                 password: Optional[str] = None, ws: Optional['websocket.WebSocketApp'] = None):
        self.host = host
        self.port = port
        self.connect_timeout = connect_timeout
        self.password = password
        self.state = self.State.DISCONNECTED
        self.ws = ws
        self._connected = threading.Event()
        self._logged_in = threading.Event()
        self._echo_received = threading.Event()
        self._response_received = threading.Event()
        # Pulsed on every download chunk; the queue below carries the actual data
        self._download_chunk_ready = threading.Event()
        self._file_transfer_request_ack_received = threading.Event()
        self._file_transfer_ack_received = threading.Event()
        self._file_transfer_request_successful = True
        self._file_transfer_successful = True
        # Downloaded chunks; a None item marks the end of the file
        self._downloaded_chunks = queue.Queue()
        self._password_requested = False
        self._running_cmd = None
        # Echo accumulated for the pending command; None when no echo is being collected
        self._received_echo = None
        self._received_response = None
        self._paste_header_received = False
        self.logger = logging.getLogger(__name__)

    def send(self, msg: Union[str, bytes], wait_response: bool = True, timeout: Optional[float] = None):
        """
        Send a command to the device.

        :param msg: Command to send. ``bytes`` input is decoded as UTF-8 first. Literal escape
            sequences ``\\x01`` .. ``\\x04`` in the text are replaced with the corresponding
            control characters.
        :param wait_response: If True, wrap the command in PASTE mode (Ctrl-E ... Ctrl-D), wait
            for its echo and then for its output.
        :param timeout: Maximum time (seconds) to wait for the echo and, separately, for the
            response.
        :return: The processed response text if ``wait_response`` is True, otherwise None.
        :raises TimeoutError: If the echo or the response does not arrive in time (the connection
            is closed first).
        """
        bufsize = 255

        if isinstance(msg, bytes):
            # The rewriting below works on text: calling bytes.replace() with str arguments
            # used to raise a TypeError for bytes input.
            msg = msg.decode()

        data = (msg
                .replace("\n", "\r\n")     # end of command in normal mode
                .replace("\\x01", "\x01")  # switch to raw mode
                .replace("\\x02", "\x02")  # switch to normal mode
                .replace("\\x03", "\x03")  # interrupt
                .replace("\\x04", "\x04")  # end of command in raw mode
                .encode())

        if not data.endswith(b'\r\n'):
            data += b'\r\n'

        if wait_response:
            # Enter PASTE mode and exit on end-of-message
            data = b'\x05' + data + b'\x04'
            self.state = self.State.SENDING_REQUEST
            self._running_cmd = data.decode().strip()
            self._received_echo = ''
            # Reset before sending: a flag left over from an earlier (e.g. timed out) request
            # would otherwise make this request's PASTE-mode banner count as echo.
            self._paste_header_received = False
            self._response_received.clear()

        self._echo_received.clear()
        for offset in range(0, len(data), bufsize):
            self.ws.send(data[offset:offset + bufsize])

        if not wait_response:
            return

        self.state = self.State.WAITING_ECHO
        if not self._echo_received.wait(timeout=timeout):
            self.on_timeout('No response echo received')

        self._paste_header_received = False
        if not self._response_received.wait(timeout=timeout):
            self.on_timeout('No response received')

        response = self._received_response
        self._received_response = None
        return response

    def on_connect(self):
        self.state = Connection.State.CONNECTED
        self._connected.set()

    def on_password_requested(self):
        """
        Answer the device's password prompt with the configured password.

        :raises AssertionError: If no password was configured.
        """
        self._password_requested = True
        self.state = Connection.State.PASSWORD_REQUIRED
        if not self.password:
            # Explicit raise instead of `assert` so the check survives `python -O`; the
            # exception type is kept for callers that catch AssertionError.
            raise AssertionError('This device is protected by password and no password was provided')
        self.send(self.password, wait_response=False)

    def on_ready(self):
        self.state = Connection.State.READY
        self._logged_in.set()

    def on_close(self):
        """Reset the connection state after the websocket has been closed."""
        self.state = self.State.DISCONNECTED
        self._connected.clear()
        self._logged_in.clear()
        self._password_requested = False
        self._paste_header_received = False
        self.ws = None

    @staticmethod
    def _normalize_echo(s: str) -> str:
        """Drop PASTE-mode control characters and carriage returns, then trim whitespace."""
        return s.replace('\x05', '').replace('\x04', '').replace('\r', '').strip()

    def on_recv_echo(self, echo):
        """
        Accumulate echoed text for the pending command.

        Once the accumulated echo matches the sent command (ignoring control characters and
        surrounding whitespace), the state moves to WAITING_RESPONSE and ``send`` is released.
        """
        if self._received_echo is None:
            # No request is collecting its echo; previously this raised TypeError (None += str)
            return

        if echo.endswith('\r\n=== ') and not self._paste_header_received:
            # PASTE-mode banner for this request
            self._paste_header_received = True
            return

        if self._prompt_re.match(echo):
            return

        self._received_echo += echo
        if self._normalize_echo(self._running_cmd) == self._normalize_echo(self._received_echo):
            self._received_echo = None
            self.state = self.State.WAITING_RESPONSE
            self._echo_received.set()

    def close(self):
        """Close the websocket (best effort) and reset the connection state."""
        try:
            if self.ws is not None:
                self.ws.close()
        except Exception as e:
            # Best effort: the socket may already be unusable
            self.logger.debug('Error while closing the websocket: %s', e)

        self.on_close()

    def on_timeout(self, msg: str = ''):
        """Close the connection and raise ``TimeoutError(msg)``."""
        self.close()
        raise TimeoutError(msg)

    def append_response(self, response):
        """Append a (possibly bytes) response fragment to the pending response text."""
        if isinstance(response, bytes):
            response = response.decode()

        if not self._received_response:
            self._received_response = ''

        self._received_response += response

    def on_recv_response(self, response):
        """
        Finalize the response of the pending command and release ``send``.

        The text is stripped, a leading PASTE-mode ``'=== '`` residue is removed and ``\\r\\n``
        line endings are converted to ``\\n``.
        """
        self.append_response(response)
        self.state = self.State.READY
        self._received_response = self._received_response.strip()

        if self._received_response.startswith('=== '):
            # Strip PASTE mode output residual
            self._received_response = self._received_response[4:].strip()

        # Replace \r\n serial end-of-line with \n
        self._received_response = self._received_response.replace('\r\n', '\n')

        # Notify the listeners
        self._response_received.set()

    def on_recv_file_transfer_response(self, response):
        """Record the device's answer to a file transfer request and release the waiter."""
        self._file_transfer_request_successful = self._parse_file_transfer_response(response)
        self._file_transfer_request_ack_received.set()

    def on_file_transfer_completed(self, response):
        """Record the device's final file transfer status and release the waiter."""
        self._file_transfer_successful = self._parse_file_transfer_response(response)
        self._file_transfer_ack_received.set()

    def on_file_upload_start(self):
        self.logger.info('Starting file upload')
        self._file_transfer_successful = False
        self._file_transfer_ack_received.clear()
        self.state = self.State.UPLOADING_FILE

    def on_file_download_start(self):
        """Reset the download queue and send the zero byte that starts the chunk exchange."""
        self.logger.info('Starting file download')
        self._file_transfer_successful = False
        self._downloaded_chunks = queue.Queue()
        self.state = self.State.DOWNLOADING_FILE
        self._file_transfer_ack_received.clear()
        self.ws.send(b'\x00', opcode=websocket.ABNF.OPCODE_BINARY)

    def on_chunk_received(self, data):
        """
        Handle a download frame: a 2-byte little-endian length followed by that many bytes.

        A zero length marks the end of the file. A non-empty chunk is queued and followed by
        another zero byte sent to the device, as in ``on_file_download_start``. Malformed frames
        are logged and discarded.
        """
        if len(data) < 2:
            self.logger.warning('Discarding download frame too short for a chunk header (%d bytes)',
                                len(data))
            return

        size = data[0] | (data[1] << 8)
        data = data[2:]
        if len(data) != size:
            self.logger.warning('Discarding download frame: header announces %d bytes, got %d',
                                size, len(data))
            return

        self.logger.info('Received chunk of size %d (total size=%d)', len(data), size)
        if size == 0:
            # End of file
            self.logger.info('File download completed')
            self._downloaded_chunks.put(None)
            self.on_file_download_completed()
        else:
            self._downloaded_chunks.put(data)
            self.ws.send(b'\x00', opcode=websocket.ABNF.OPCODE_BINARY)

        # Event pulse presumably kept for waiters from the previous event-based protocol
        self._download_chunk_ready.set()
        self._download_chunk_ready.clear()

    def on_file_download_completed(self):
        self.state = self.State.READY

    def on_file_transfer_request(self):
        """Prepare the state and the ack event for a new file transfer request."""
        self.logger.info('Sending file transfer request')
        self._file_transfer_request_successful = False
        self._file_transfer_request_ack_received.clear()
        self.state = self.State.SENDING_FILE_TRANSFER_RESPONSE

    def get_downloaded_chunks(self, timeout: Optional[float] = None):
        """
        Yield downloaded chunks until the end-of-file marker (None) is dequeued.

        :param timeout: Maximum time to wait for each single chunk.
        :raises TimeoutError: If a chunk does not arrive in time (the connection is closed first).
        """
        while True:
            try:
                chunk = self._downloaded_chunks.get(timeout=timeout)
            except queue.Empty:
                # on_timeout() always raises, so `chunk` is never used unbound below
                self.on_timeout('File download timed out')

            if chunk is None:
                break

            yield chunk

    def wait_ready(self):
        """
        Wait up to ``connect_timeout`` for the connection, then up to ``connect_timeout`` for login.

        :raises TimeoutError: If either step does not complete in time.
        """
        if not self._connected.wait(timeout=self.connect_timeout):
            self.on_timeout('Connection timed out')

        if not self._logged_in.wait(timeout=self.connect_timeout):
            self.on_timeout('Log in timed out')

    def wait_file_request_ack_received(self, timeout):
        """
        Wait for the device's answer to a file transfer request.

        :raises TimeoutError: If no answer arrives in time (the connection is closed first).
        :raises AssertionError: If the device rejected the request.
        """
        self.state = self.State.WAITING_FILE_TRANSFER_RESPONSE
        if not self._file_transfer_request_ack_received.wait(timeout=timeout):
            self.on_timeout('File transfer request timed out')

        if not self._file_transfer_request_successful:
            # Explicit raise (not `assert`) so the check is not stripped under `python -O`
            raise AssertionError('File transfer request failed')

        self.logger.info('File transfer request acknowledged')

    def wait_file_transfer_completed(self, timeout):
        """
        Wait for the device's final status after a file transfer.

        :raises TimeoutError: If no status arrives in time (the connection is closed first).
        :raises AssertionError: If the device reported a failure.
        """
        if not self._file_transfer_ack_received.wait(timeout):
            self.on_timeout('File transfer timed out')

        if not self._file_transfer_successful:
            # Explicit raise (not `assert`) so the check is not stripped under `python -O`
            raise AssertionError('File transfer failed')

        self.logger.info('File transfer completed')

    @staticmethod
    def _parse_file_transfer_response(response: bytes) -> bool:
        """Return True if ``response`` is a ``b'WB'`` status frame with a zero status code."""
        if not response or len(response) < 4:
            return False

        return response[:2] == b'WB' and int.from_bytes(response[2:4], 'little') == 0

    @staticmethod
    def _build_file_request(filename: str, request_type: FileRequestType, file_size: int = 0) -> bytes:
        """
        Build the 82-byte file transfer request header.

        Layout (little endian, 2 + 1 + 1 + 8 + 4 + 2 + 64 bytes): ``b'WA'``, request type,
        a zero byte, 8 zero bytes, file size, file name length, zero-padded file name.

        :param filename: Path on the device; encoded as UTF-8.
        :param request_type: Upload or download.
        :param file_size: Size in bytes of the file being uploaded (0 for downloads).
        :raises ValueError: If the encoded name exceeds 64 bytes or the size exceeds 32 bits.
        """
        name = filename.encode('utf-8')
        if len(name) > 64:
            raise ValueError('File name too long for a WebREPL request ({} bytes, max 64): {}'.format(
                len(name), filename))
        if not 0 <= file_size <= 0xffffffff:
            raise ValueError('Invalid file size for a WebREPL request: {}'.format(file_size))

        return struct.pack('<2sBBQLH64s', b'WA', int(request_type), 0, 0, file_size, len(name), name)

    def _send_file_request(self, filename: str, request_type: FileRequestType, file_size: int = 0,
                           timeout: Optional[float] = None):
        """Send a file transfer request and wait for the device to acknowledge it."""
        # Build first, so an invalid request does not alter the connection state
        request = self._build_file_request(filename, request_type, file_size)
        self.on_file_transfer_request()
        self.ws.send(request, opcode=websocket.ABNF.OPCODE_BINARY)
        self.wait_file_request_ack_received(timeout=timeout)

    def _upload_file(self, f, timeout):
        """Stream ``f`` to the device in buffer-sized binary frames, then wait for the status."""
        self.on_file_upload_start()
        content = f.read(self._file_transfer_buffer_size)

        while content:
            self.ws.send(content, opcode=websocket.ABNF.OPCODE_BINARY)
            content = f.read(self._file_transfer_buffer_size)

        self.wait_file_transfer_completed(timeout=timeout)
        self.state = self.State.READY

    def _download_file(self, f, timeout):
        """Write every downloaded chunk to the file object ``f``."""
        self.on_file_download_start()
        for chunk in self.get_downloaded_chunks(timeout=timeout):
            f.write(chunk)

        self.on_file_download_completed()

    def file_upload(self, source, destination, timeout):
        """
        Upload the local file ``source`` into the device directory ``destination``.

        :param source: Local path (``~`` is expanded).
        :param destination: Directory on the device; if empty, the file name alone is used.
        :param timeout: Timeout for the request acknowledgement and for the final status.
        """
        source = os.path.abspath(os.path.expanduser(source))
        filename = os.path.basename(source)
        # The destination lives on the device, whose paths always use '/' (os.path.join would
        # produce backslashes on Windows hosts)
        destination = posixpath.join(destination, filename) if destination else filename
        size = os.path.getsize(source)

        with open(source, 'rb') as f:
            self._send_file_request(destination, self.FileRequestType.UPLOAD, file_size=size, timeout=timeout)
            self._upload_file(f, timeout=timeout)

    def file_download(self, file, fd, timeout=None):
        """Download the device file ``file`` into the writable binary file object ``fd``."""
        self._send_file_request(file, request_type=self.FileRequestType.DOWNLOAD, timeout=timeout)
        self._download_file(fd, timeout=timeout)
# vim:sw=4:ts=4:et: