Merge branch 'master' into 29-generic-entities-support

This commit is contained in:
Fabio Manganiello 2022-11-21 12:36:01 +01:00
commit ba1681fc22
6 changed files with 51 additions and 56 deletions

View file

@ -153,7 +153,10 @@ class CronScheduler(threading.Thread):
for (job_name, job_config) in self.jobs_config.items(): for (job_name, job_config) in self.jobs_config.items():
job = self._get_job(name=job_name, config=job_config) job = self._get_job(name=job_name, config=job_config)
if job.state == CronjobState.IDLE: if job.state == CronjobState.IDLE:
job.start() try:
job.start()
except Exception as e:
logger.warning(f'Could not start cronjob {job_name}: {e}')
t_before_wait = get_now().timestamp() t_before_wait = get_now().timestamp()
self._should_stop.wait(timeout=self._poll_seconds) self._should_stop.wait(timeout=self._poll_seconds)

View file

@ -158,7 +158,8 @@ class AsyncRunnablePlugin(RunnablePlugin, ABC):
asyncio.set_event_loop(self._loop) asyncio.set_event_loop(self._loop)
self._task = self._loop.create_task(self._listen()) self._task = self._loop.create_task(self._listen())
self._task.set_name(self.__class__.__name__ + '.listen') if hasattr(self._task, 'set_name'):
self._task.set_name(self.__class__.__name__ + '.listen')
self._loop.run_forever() self._loop.run_forever()
def main(self): def main(self):

View file

@ -1,17 +1,14 @@
import base64
import datetime import datetime
import hashlib import hashlib
import json
import random import random
import rsa
import time import time
from typing import Optional, Dict from typing import Optional, Dict
import bcrypt import bcrypt
try:
from jwt.exceptions import PyJWTError
from jwt import encode as jwt_encode, decode as jwt_decode
except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient from sqlalchemy.orm import make_transient
@ -229,17 +226,20 @@ class UserManager:
if not user: if not user:
raise InvalidCredentialsException() raise InvalidCredentialsException()
_, priv_key = get_or_generate_jwt_rsa_key_pair() pub_key, _ = get_or_generate_jwt_rsa_key_pair()
payload = { payload = json.dumps(
'username': username, {
'created_at': datetime.datetime.now().timestamp(), 'username': username,
'expires_at': expires_at.timestamp() if expires_at else None, 'created_at': datetime.datetime.now().timestamp(),
} 'expires_at': expires_at.timestamp() if expires_at else None,
},
sort_keys=True,
indent=None,
)
token = jwt_encode(payload, priv_key, algorithm='RS256') return base64.b64encode(
if isinstance(token, bytes): rsa.encrypt(payload.encode('ascii'), pub_key)
token = token.decode() ).decode()
return token
@staticmethod @staticmethod
def validate_jwt_token(token: str) -> Dict[str, str]: def validate_jwt_token(token: str) -> Dict[str, str]:
@ -259,12 +259,17 @@ class UserManager:
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token. :raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
""" """
pub_key, _ = get_or_generate_jwt_rsa_key_pair() _, priv_key = get_or_generate_jwt_rsa_key_pair()
try: try:
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) # type: ignore[reportGeneralTypeIssues] payload = json.loads(
except PyJWTError as e: rsa.decrypt(
raise InvalidJWTTokenException(str(e)) base64.b64decode(token.encode('ascii')),
priv_key
).decode('ascii')
)
except (TypeError, ValueError) as e:
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
expires_at = payload.get('expires_at') expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at: if expires_at and time.time() > expires_at:

View file

@ -8,6 +8,7 @@ import logging
import os import os
import pathlib import pathlib
import re import re
import rsa
import signal import signal
import socket import socket
import ssl import ssl
@ -16,6 +17,7 @@ from typing import Optional, Tuple, Union
from dateutil import parser, tz from dateutil import parser, tz
from redis import Redis from redis import Redis
from rsa.key import PublicKey, PrivateKey
logger = logging.getLogger('utils') logger = logging.getLogger('utils')
@ -387,9 +389,8 @@ def run(action, *args, **kwargs):
return response.output return response.output
def generate_rsa_key_pair( def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) \
key_file: Optional[str] = None, size: int = 2048 -> Tuple[PublicKey, PrivateKey]:
) -> Tuple[str, str]:
""" """
Generate an RSA key pair. Generate an RSA key pair.
@ -407,30 +408,11 @@ def generate_rsa_key_pair(
:param size: Key size (default: 2048 bits). :param size: Key size (default: 2048 bits).
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``. :return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
""" """
from cryptography.hazmat.primitives import serialization logger.info('Generating RSA keypair')
from cryptography.hazmat.primitives.asymmetric import rsa pub_key, priv_key = rsa.newkeys(size)
from cryptography.hazmat.backends import default_backend logger.info('Generated RSA keypair')
public_key_str = pub_key.save_pkcs1('PEM').decode()
public_exp = 65537 private_key_str = priv_key.save_pkcs1('PEM').decode()
private_key = rsa.generate_private_key(
public_exponent=public_exp, key_size=size, backend=default_backend()
)
logger.info('Generating RSA {} key pair'.format(size))
private_key_str = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
).decode()
public_key_str = (
private_key.public_key()
.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1,
)
.decode()
)
if key_file: if key_file:
logger.info('Saving private key to {}'.format(key_file)) logger.info('Saving private key to {}'.format(key_file))
@ -441,7 +423,7 @@ def generate_rsa_key_pair(
f2.write(public_key_str) f2.write(public_key_str)
os.chmod(key_file, 0o600) os.chmod(key_file, 0o600)
return public_key_str, private_key_str return pub_key, priv_key
def get_or_generate_jwt_rsa_key_pair(): def get_or_generate_jwt_rsa_key_pair():
@ -452,8 +434,14 @@ def get_or_generate_jwt_rsa_key_pair():
pub_key_file = priv_key_file + '.pub' pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, open(priv_key_file, 'r') as f2: with (
return f1.read(), f2.read() open(pub_key_file, 'r') as f1,
open(priv_key_file, 'r') as f2
):
return (
rsa.PublicKey.load_pkcs1(f1.read().encode()),
rsa.PrivateKey.load_pkcs1(f2.read().encode()),
)
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048) return generate_rsa_key_pair(priv_key_file, size=2048)

View file

@ -14,8 +14,7 @@ frozendict
requests requests
sqlalchemy sqlalchemy
bcrypt bcrypt
cryptography rsa
pyjwt
zeroconf zeroconf
paho-mqtt paho-mqtt
websocket-client websocket-client

View file

@ -64,8 +64,7 @@ setup(
'zeroconf>=0.27.0', 'zeroconf>=0.27.0',
'tz', 'tz',
'python-dateutil', 'python-dateutil',
# 'cryptography', 'rsa',
'pyjwt',
'marshmallow', 'marshmallow',
'frozendict', 'frozendict',
'flask', 'flask',