LINT/regex fix for ESP plugin.

This commit is contained in:
Fabio Manganiello 2023-09-17 17:10:40 +02:00
parent ecba2e49ac
commit c6cda86b1c
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774

View file

@ -4,9 +4,10 @@ import queue
import os import os
import re import re
import threading import threading
import websocket
from typing import Optional, Union from typing import Optional, Union
import websocket
from platypush.utils import grouper from platypush.utils import grouper
@ -34,8 +35,14 @@ class Connection:
UPLOAD = 1 UPLOAD = 1
DOWNLOAD = 2 DOWNLOAD = 2
def __init__(self, host: str, port: int, connect_timeout: Optional[float] = None, def __init__(
password: Optional[str] = None, ws: Optional[websocket.WebSocketApp] = None): self,
host: str,
port: int,
connect_timeout: Optional[float] = None,
password: Optional[str] = None,
ws: Optional[websocket.WebSocketApp] = None,
):
self.host = host self.host = host
self.port = port self.port = port
self.connect_timeout = connect_timeout self.connect_timeout = connect_timeout
@ -59,16 +66,22 @@ class Connection:
self._paste_header_received = False self._paste_header_received = False
self.logger = logging.getLogger('platypush:plugin:esp') self.logger = logging.getLogger('platypush:plugin:esp')
def send(self, msg: Union[str, bytes], wait_response: bool = True, timeout: Optional[float] = None): def send(
self,
msg: Union[str, bytes],
wait_response: bool = True,
timeout: Optional[float] = None,
):
bufsize = 255 bufsize = 255
msg = (msg msg = (
.replace("\n", "\r\n") # end of command in normal mode msg.replace("\n", "\r\n") # end of command in normal mode
.replace("\\x01", "\x01") # switch to raw mode .replace("\\x01", "\x01") # switch to raw mode
.replace("\\x02", "\x02") # switch to normal mode .replace("\\x02", "\x02") # switch to normal mode
.replace("\\x03", "\x03") # interrupt .replace("\\x03", "\x03") # interrupt
.replace("\\x04", "\x04") # end of command in raw mode .replace("\\x04", "\x04") # end of command in raw mode
.encode()) .encode()
)
if not msg.endswith(b'\r\n'): if not msg.endswith(b'\r\n'):
msg += b'\r\n' msg += b'\r\n'
@ -110,7 +123,9 @@ class Connection:
def on_password_requested(self): def on_password_requested(self):
self._password_requested = True self._password_requested = True
self.state = Connection.State.PASSWORD_REQUIRED self.state = Connection.State.PASSWORD_REQUIRED
assert self.password, 'This device is protected by password and no password was provided' assert (
self.password
), 'This device is protected by password and no password was provided'
self.send(self.password, wait_response=False) self.send(self.password, wait_response=False)
def on_ready(self): def on_ready(self):
@ -127,17 +142,19 @@ class Connection:
def on_recv_echo(self, echo): def on_recv_echo(self, echo):
def str_transform(s: str): def str_transform(s: str):
s = s.replace('\x05', '').replace('\x04', '').replace('\r', '') s = s.replace('\x05', '').replace('\x04', '').replace('\r', '')
s = re.sub('^[\s\r\n]+', '', s) s = re.sub(r'^[\s\r\n]+', '', s)
s = re.sub('[\s\r\n]+$', '', s) s = re.sub(r'[\s\r\n]+$', '', s)
return s return s
if echo.endswith('\r\n=== ') and not self._paste_header_received: if echo.endswith('\r\n=== ') and not self._paste_header_received:
self._paste_header_received = True self._paste_header_received = True
return return
if re.match('\s+>>>\s+', echo) \ if (
or re.match('\s+\.\.\.\s+', echo) \ re.match(r'\s+>>>\s+', echo)
or re.match('\s+===\s+', echo): or re.match(r'\s+\.\.\.\s+', echo)
or re.match(r'\s+===\s+', echo)
):
return return
self._received_echo += echo self._received_echo += echo
@ -187,7 +204,9 @@ class Connection:
self._response_received.set() self._response_received.set()
def on_recv_file_transfer_response(self, response): def on_recv_file_transfer_response(self, response):
self._file_transfer_request_successful = self._parse_file_transfer_response(response) self._file_transfer_request_successful = self._parse_file_transfer_response(
response
)
self._file_transfer_request_ack_received.set() self._file_transfer_request_ack_received.set()
def on_file_transfer_completed(self, response): def on_file_transfer_completed(self, response):
@ -214,7 +233,9 @@ class Connection:
if len(data) != size: if len(data) != size:
return return
self.logger.info('Received chunk of size {} (total size={})'.format(len(data), size)) self.logger.info(
'Received chunk of size {} (total size={})'.format(len(data), size)
)
if size == 0: if size == 0:
# End of file # End of file
self.logger.info('File download completed') self.logger.info('File download completed')
@ -242,7 +263,6 @@ class Connection:
chunk = self._downloaded_chunks.get(timeout=timeout) chunk = self._downloaded_chunks.get(timeout=timeout)
except queue.Empty: except queue.Empty:
self.on_timeout('File download timed out') self.on_timeout('File download timed out')
break
if chunk is None: if chunk is None:
break break
@ -278,8 +298,13 @@ class Connection:
return response[2] | response[3] << 8 == 0 return response[2] | response[3] << 8 == 0
return False return False
def _send_file_request(self, filename: str, request_type: FileRequestType, file_size: int = 0, def _send_file_request(
timeout: Optional[float] = None): self,
filename: str,
request_type: FileRequestType,
file_size: int = 0,
timeout: Optional[float] = None,
):
self.on_file_transfer_request() self.on_file_transfer_request()
# 2 + 1 + 1 + 8 + 4 + 2 + 64 # 2 + 1 + 1 + 8 + 4 + 2 + 64
@ -291,14 +316,14 @@ class Connection:
request[2] = request_type.value request[2] = request_type.value
# File size encoding # File size encoding
request[12] = file_size & 0xff request[12] = file_size & 0xFF
request[13] = (file_size >> 8) & 0xff request[13] = (file_size >> 8) & 0xFF
request[14] = (file_size >> 16) & 0xff request[14] = (file_size >> 16) & 0xFF
request[15] = (file_size >> 24) & 0xff request[15] = (file_size >> 24) & 0xFF
# File name length encoding # File name length encoding
request[16] = len(filename) & 0xff request[16] = len(filename) & 0xFF
request[17] = (len(filename) >> 8) & 0xff request[17] = (len(filename) >> 8) & 0xFF
# File name encoding # File name encoding
for i in range(0, min(64, len(filename))): for i in range(0, min(64, len(filename))):
@ -327,15 +352,26 @@ class Connection:
def file_upload(self, source, destination, timeout): def file_upload(self, source, destination, timeout):
source = os.path.abspath(os.path.expanduser(source)) source = os.path.abspath(os.path.expanduser(source))
destination = os.path.join(destination, os.path.basename(source)) if destination else os.path.basename(source) destination = (
os.path.join(destination, os.path.basename(source))
if destination
else os.path.basename(source)
)
size = os.path.getsize(source) size = os.path.getsize(source)
with open(source, 'rb') as f: with open(source, 'rb') as f:
self._send_file_request(destination, self.FileRequestType.UPLOAD, file_size=size, timeout=timeout) self._send_file_request(
destination,
self.FileRequestType.UPLOAD,
file_size=size,
timeout=timeout,
)
self._upload_file(f, timeout=timeout) self._upload_file(f, timeout=timeout)
def file_download(self, file, fd, timeout=None): def file_download(self, file, fd, timeout=None):
self._send_file_request(file, request_type=self.FileRequestType.DOWNLOAD, timeout=timeout) self._send_file_request(
file, request_type=self.FileRequestType.DOWNLOAD, timeout=timeout
)
self._download_file(fd, timeout=timeout) self._download_file(fd, timeout=timeout)