platypush/platypush/plugins/esp/models/connection.py

343 lines
12 KiB
Python
Raw Normal View History

import enum
import logging
import os
import queue
import re
import struct
import threading
from typing import Optional, Union

import websocket

from platypush.utils import grouper


class Connection:
    """
    This class models the connection with an ESP8266/ESP32 device over its WebREPL websocket channel.

    The ``on_*`` methods are callbacks that update the connection state when messages arrive
    (presumably invoked by the websocket handlers that own ``ws``, which are not defined here),
    while :meth:`send`, :meth:`file_upload` and :meth:`file_download` are blocking calls that wait
    on :class:`threading.Event` objects set by those callbacks.
    """

    # Size of the blocks read from a local file and sent to the device during an upload
    _file_transfer_buffer_size = 1024

    # File transfer request header (little-endian): b'WA' signature, request type, one zero byte,
    # 8 zero bytes, 32-bit file size, 16-bit file name length, 64-byte zero-padded file name.
    # 2 + 1 + 1 + 8 + 4 + 2 + 64 = 82 bytes.
    _file_request_header_format = '<2sBBQLH64s'
    _file_request_max_filename_size = 64

    class State(enum.IntEnum):
        """
        States of the connection/request state machine.
        """
        DISCONNECTED = 1
        CONNECTED = 2
        PASSWORD_REQUIRED = 3
        READY = 4
        SENDING_REQUEST = 5
        WAITING_ECHO = 6
        WAITING_RESPONSE = 7
        SENDING_FILE_TRANSFER_RESPONSE = 8
        WAITING_FILE_TRANSFER_RESPONSE = 9
        UPLOADING_FILE = 10
        DOWNLOADING_FILE = 11

    class FileRequestType(enum.IntEnum):
        """
        File transfer request types, as encoded in the request header.
        """
        UPLOAD = 1
        DOWNLOAD = 2

    def __init__(self, host: str, port: int, connect_timeout: Optional[float] = None,
                 password: Optional[str] = None, ws: Optional['websocket.WebSocketApp'] = None):
        """
        :param host: Device host name or address.
        :param port: WebREPL websocket port.
        :param connect_timeout: Timeout, in seconds, applied by :meth:`wait_ready` to both the
            connection and the login phase (default: None, wait indefinitely).
        :param password: WebREPL password, sent when the device asks for one.
        :param ws: Websocket used to send messages to the device.
        """
        self.host = host
        self.port = port
        self.connect_timeout = connect_timeout
        self.password = password
        self.state = self.State.DISCONNECTED
        self.ws = ws

        # Events used to hand control from the receiving callbacks to the blocking callers
        self._connected = threading.Event()
        self._logged_in = threading.Event()
        self._echo_received = threading.Event()
        self._response_received = threading.Event()
        self._download_chunk_ready = threading.Event()
        self._file_transfer_request_ack_received = threading.Event()
        self._file_transfer_ack_received = threading.Event()

        self._file_transfer_request_successful = True
        self._file_transfer_successful = True
        # Downloaded file chunks; a None item marks the end of the file
        self._downloaded_chunks = queue.Queue()
        self._password_requested = False
        # Command waiting for its echo/response, as sent (PASTE mode control characters included)
        self._running_cmd = None
        self._received_echo = None
        self._received_response = None
        self._paste_header_received = False
        self.logger = logging.getLogger(__name__)

    def send(self, msg: Union[str, bytes], wait_response: bool = True, timeout: Optional[float] = None):
        r"""
        Send a command to the device.

        Literal ``\x01``-``\x04`` escape sequences in the message are converted to the matching
        control characters, ``\n`` is converted to ``\r\n`` and a trailing ``\r\n`` is appended
        if missing. The message is sent in chunks of at most 255 bytes.

        :param msg: Command to send. ``bytes`` are decoded as UTF-8 first.
        :param wait_response: If True, wrap the command in PASTE mode (``\x05`` ... ``\x04``),
            wait for its echo and then for its response.
        :param timeout: Timeout, in seconds, for each of the echo and the response waits.
        :return: The processed response if ``wait_response`` is True, None otherwise.
        :raises TimeoutError: If the echo or the response does not arrive within ``timeout``;
            the connection is closed in that case.
        """
        bufsize = 255
        if isinstance(msg, bytes):
            # The str replacements below would raise TypeError on a bytes object
            msg = msg.decode()

        msg = (msg
               .replace("\n", "\r\n")  # end of command in normal mode
               .replace("\\x01", "\x01")  # switch to raw mode
               .replace("\\x02", "\x02")  # switch to normal mode
               .replace("\\x03", "\x03")  # interrupt
               .replace("\\x04", "\x04")  # end of command in raw mode
               .encode())

        if not msg.endswith(b'\r\n'):
            msg += b'\r\n'

        if wait_response:
            # Enter PASTE mode and exit on end-of-message
            msg = b'\x05' + msg + b'\x04'
            self.state = self.State.SENDING_REQUEST
            self._running_cmd = msg.decode().strip()
            self._received_echo = ''
            self._response_received.clear()
            self._echo_received.clear()

        # Plain slicing sends every byte verbatim (NUL bytes included)
        for offset in range(0, len(msg), bufsize):
            self.ws.send(msg[offset:offset + bufsize])

        if not wait_response:
            return

        self.state = self.State.WAITING_ECHO
        if not self._echo_received.wait(timeout=timeout):
            self.on_timeout('No response echo received')

        self._paste_header_received = False
        if not self._response_received.wait(timeout=timeout):
            self.on_timeout('No response received')

        response = self._received_response
        self._received_response = None
        return response

    def on_connect(self):
        """
        Callback for an established connection.
        """
        self.state = Connection.State.CONNECTED
        self._connected.set()

    def on_password_requested(self):
        """
        Callback for a password prompt: send the configured password.

        :raises AssertionError: If no password was configured.
        """
        self._password_requested = True
        self.state = Connection.State.PASSWORD_REQUIRED
        # Explicit raise rather than ``assert``, so the check is not stripped by ``python -O``
        if not self.password:
            raise AssertionError('This device is protected by password and no password was provided')
        self.send(self.password, wait_response=False)

    def on_ready(self):
        """
        Callback for a successful login: the device is ready to accept commands.
        """
        self.state = Connection.State.READY
        self._logged_in.set()

    def on_close(self):
        """
        Callback for a closed connection: reset the state and drop the websocket reference.
        """
        self.state = self.State.DISCONNECTED
        self._connected.clear()
        self._logged_in.clear()
        self._password_requested = False
        self.ws = None

    def on_recv_echo(self, echo):
        """
        Callback for echoed input: accumulate it until it matches the running command, then signal
        :meth:`send` that the response comes next.

        The first chunk ending with the PASTE mode header and REPL prompt chunks are ignored. The
        comparison ignores ``\\x04``/``\\x05``, carriage returns and leading/trailing whitespace.

        :param echo: Echoed text chunk.
        """
        def normalize(s: str) -> str:
            return s.replace('\x05', '').replace('\x04', '').replace('\r', '').strip()

        if echo.endswith('\r\n=== ') and not self._paste_header_received:
            self._paste_header_received = True
            return

        if re.match(r'\s+>>>\s+', echo) \
                or re.match(r'\s+\.\.\.\s+', echo) \
                or re.match(r'\s+===\s+', echo):
            return

        if self._running_cmd is None:
            # No command is waiting for its echo
            return

        # _received_echo is reset to None after each match
        self._received_echo = (self._received_echo or '') + echo
        if normalize(self._running_cmd) == normalize(self._received_echo):
            self._received_echo = None
            self.state = self.State.WAITING_RESPONSE
            self._echo_received.set()

    def close(self):
        """
        Close the websocket (best effort) and reset the connection state.
        """
        if self.ws is not None:
            try:
                self.ws.close()
            except Exception as e:
                self.logger.debug('Error while closing the websocket: %s', e)

        self.on_close()

    def on_timeout(self, msg: str = ''):
        """
        Close the connection and raise a timeout.

        :raises TimeoutError: Always, with ``msg`` as message.
        """
        self.close()
        raise TimeoutError(msg)

    def append_response(self, response):
        """
        Append a (possibly partial) response chunk to the response buffer.

        :param response: Response chunk, as ``str`` or UTF-8 ``bytes``.
        """
        if isinstance(response, bytes):
            response = response.decode()
        self._received_response = (self._received_response or '') + response

    def on_recv_response(self, response):
        """
        Callback for the command response: clean up the buffered response and wake up :meth:`send`.

        :param response: Last response chunk, as ``str`` or UTF-8 ``bytes``.
        """
        self.append_response(response)
        self.state = self.State.READY
        self._received_response = self._received_response.strip()
        if self._received_response.startswith('=== '):
            # Strip PASTE mode output residual
            self._received_response = self._received_response[4:].strip()

        # Replace \r\n serial end-of-line with \n
        self._received_response = self._received_response.replace('\r\n', '\n')
        # Notify the listeners
        self._response_received.set()

    def on_recv_file_transfer_response(self, response):
        """
        Callback for the device's reply to a file transfer request.

        :param response: Raw response bytes (see :meth:`_parse_file_transfer_response`).
        """
        self._file_transfer_request_successful = self._parse_file_transfer_response(response)
        self._file_transfer_request_ack_received.set()

    def on_file_transfer_completed(self, response):
        """
        Callback for the device's final reply to a file transfer.

        :param response: Raw response bytes (see :meth:`_parse_file_transfer_response`).
        """
        self._file_transfer_successful = self._parse_file_transfer_response(response)
        self._file_transfer_ack_received.set()

    def on_file_upload_start(self):
        """
        Prepare for a file upload: reset the completion flag/event and switch to UPLOADING_FILE.
        """
        self.logger.info('Starting file upload')
        self._file_transfer_successful = False
        self._file_transfer_ack_received.clear()
        self.state = self.State.UPLOADING_FILE

    def on_file_download_start(self):
        """
        Prepare for a file download and request the first chunk by sending a 0x00 byte.
        """
        self.logger.info('Starting file download')
        self._file_transfer_successful = False
        self._downloaded_chunks = queue.Queue()
        self.state = self.State.DOWNLOADING_FILE
        self._file_transfer_ack_received.clear()
        self.ws.send(b'\x00', opcode=websocket.ABNF.OPCODE_BINARY)

    def on_chunk_received(self, data):
        """
        Callback for a downloaded data frame: a 16-bit little-endian payload size followed by the
        payload. A zero size marks the end of the file; otherwise the chunk is queued and the next
        one is requested. Frames whose payload length does not match the declared size are ignored.

        :param data: Raw frame bytes.
        """
        if len(data) < 2:
            # Too short to contain the size header
            return

        size = data[0] | (data[1] << 8)
        data = data[2:]
        if len(data) != size:
            return

        self.logger.info('Received chunk of size %d (total size=%d)', len(data), size)
        if size == 0:
            # End of file
            self.logger.info('File download completed')
            self._downloaded_chunks.put(None)
            self.on_file_download_completed()
        else:
            self._downloaded_chunks.put(data)
            # Request the next chunk
            self.ws.send(b'\x00', opcode=websocket.ABNF.OPCODE_BINARY)

        self._download_chunk_ready.set()
        self._download_chunk_ready.clear()

    def on_file_download_completed(self):
        """
        Switch back to READY after a download.
        """
        self.state = self.State.READY

    def on_file_transfer_request(self):
        """
        Prepare for a file transfer request: reset the request ack flag/event.
        """
        self.logger.info('Sending file transfer request')
        self._file_transfer_request_successful = False
        self._file_transfer_request_ack_received.clear()
        self.state = self.State.SENDING_FILE_TRANSFER_RESPONSE

    def get_downloaded_chunks(self, timeout: Optional[float] = None):
        """
        Yield the downloaded chunks as they are queued, until the end-of-file marker.

        :param timeout: Maximum time, in seconds, to wait for each chunk (default: no limit).
        :raises TimeoutError: If a chunk does not arrive in time; the connection is closed.
        """
        while True:
            try:
                chunk = self._downloaded_chunks.get(timeout=timeout)
            except queue.Empty:
                # on_timeout always raises
                self.on_timeout('File download timed out')
            if chunk is None:
                break
            yield chunk

    def wait_ready(self):
        """
        Wait until the connection is established and the login completed.

        :raises TimeoutError: If either phase exceeds ``connect_timeout``; the connection is closed.
        """
        if not self._connected.wait(timeout=self.connect_timeout):
            self.on_timeout('Connection timed out')
        if not self._logged_in.wait(timeout=self.connect_timeout):
            self.on_timeout('Log in timed out')

    def wait_file_request_ack_received(self, timeout: Optional[float] = None):
        """
        Wait for the device to acknowledge a file transfer request.

        :param timeout: Timeout in seconds (default: no limit).
        :raises TimeoutError: If no acknowledgement arrives in time; the connection is closed.
        :raises AssertionError: If the device reported a failure.
        """
        self.state = self.State.WAITING_FILE_TRANSFER_RESPONSE
        if not self._file_transfer_request_ack_received.wait(timeout=timeout):
            self.on_timeout('File transfer request timed out')
        # Explicit raise rather than ``assert``, so the check is not stripped by ``python -O``
        if not self._file_transfer_request_successful:
            raise AssertionError('File transfer request failed')
        self.logger.info('File transfer request acknowledged')

    def wait_file_transfer_completed(self, timeout: Optional[float] = None):
        """
        Wait for the device to confirm the completion of a file transfer.

        :param timeout: Timeout in seconds (default: no limit).
        :raises TimeoutError: If no confirmation arrives in time; the connection is closed.
        :raises AssertionError: If the device reported a failure.
        """
        if not self._file_transfer_ack_received.wait(timeout):
            self.on_timeout('File transfer timed out')
        # Explicit raise rather than ``assert``, so the check is not stripped by ``python -O``
        if not self._file_transfer_successful:
            raise AssertionError('File transfer failed')
        self.logger.info('File transfer completed')

    @staticmethod
    def _parse_file_transfer_response(response: bytes) -> bool:
        """
        Parse a file transfer response: ``b'WB'`` followed by a 16-bit little-endian status code.

        :return: True if the response is well-formed and the status is 0, False otherwise.
        """
        if not response or len(response) < 4:
            return False
        if response[0] == ord('W') and response[1] == ord('B'):
            return (response[2] | (response[3] << 8)) == 0
        return False

    @classmethod
    def _encode_file_request(cls, filename: str, request_type: FileRequestType, file_size: int = 0) -> bytes:
        """
        Build the binary header of a file transfer request.

        :param filename: Remote file name, encoded as UTF-8.
        :param request_type: Upload or download.
        :param file_size: File size in bytes (0 for downloads, as used by :meth:`file_download`).
        :return: The 82-byte request header.
        :raises ValueError: If the encoded name does not fit the 64-byte name field (truncating it
            would silently target a different file).
        """
        encoded_filename = filename.encode('utf-8')
        if len(encoded_filename) > cls._file_request_max_filename_size:
            raise ValueError('File name too long: {} bytes, the maximum is {}'.format(
                len(encoded_filename), cls._file_request_max_filename_size))

        # The '64s' field zero-pads the name; the length field carries its actual size
        return struct.pack(cls._file_request_header_format, b'WA', request_type.value, 0, 0,
                           file_size, len(encoded_filename), encoded_filename)

    def _send_file_request(self, filename: str, request_type: FileRequestType, file_size: int = 0,
                           timeout: Optional[float] = None):
        """
        Send a file transfer request and wait for the device to acknowledge it.

        :raises ValueError: If the file name is too long to be encoded (nothing is sent).
        :raises TimeoutError: If no acknowledgement arrives within ``timeout``.
        :raises AssertionError: If the device rejects the request.
        """
        # Encode first, so an invalid request leaves the connection state untouched
        request = self._encode_file_request(filename, request_type, file_size=file_size)
        self.on_file_transfer_request()
        self.ws.send(request, opcode=websocket.ABNF.OPCODE_BINARY)
        self.wait_file_request_ack_received(timeout=timeout)

    def _upload_file(self, f, timeout):
        """
        Stream an open binary file to the device and wait for the completion confirmation.
        """
        self.on_file_upload_start()
        content = f.read(self._file_transfer_buffer_size)
        while content:
            self.ws.send(content, opcode=websocket.ABNF.OPCODE_BINARY)
            content = f.read(self._file_transfer_buffer_size)

        self.wait_file_transfer_completed(timeout=timeout)
        self.state = self.State.READY

    def _download_file(self, f, timeout):
        """
        Write the downloaded chunks to ``f`` as they arrive.
        """
        self.on_file_download_start()
        for chunk in self.get_downloaded_chunks(timeout=timeout):
            f.write(chunk)
        self.on_file_download_completed()

    def file_upload(self, source, destination, timeout=None):
        """
        Upload a local file to the device.

        :param source: Local file path (``~`` is expanded).
        :param destination: Remote directory the file is placed in, keeping its base name; if empty,
            the base name alone is used.
        :param timeout: Timeout, in seconds, for the request acknowledgement and the transfer completion.
        """
        source = os.path.abspath(os.path.expanduser(source))
        destination = os.path.join(destination, os.path.basename(source)) if destination else os.path.basename(source)
        size = os.path.getsize(source)

        with open(source, 'rb') as f:
            self._send_file_request(destination, self.FileRequestType.UPLOAD, file_size=size, timeout=timeout)
            self._upload_file(f, timeout=timeout)

    def file_download(self, file, fd, timeout=None):
        """
        Download a file from the device.

        :param file: Remote file path.
        :param fd: Writable file object that receives the downloaded byte chunks.
        :param timeout: Timeout, in seconds, for the request acknowledgement and for each chunk.
        """
        self._send_file_request(file, request_type=self.FileRequestType.DOWNLOAD, timeout=timeout)
        self._download_file(fd, timeout=timeout)
# vim:sw=4:ts=4:et: