Removed PyJWT dependency.

PyJWT is a very brittle and cumbersome dependency that expects several
cryptography libraries to be already installed on the system, and it can
lead to hard-to-debug errors when ported to different systems.

Moreover, it installs the whole `cryptography` package, which is several
MBs in size, takes time to compile, and it requires a Rust compiler to
be present on the target machine.

Platypush will now use the Python-native `rsa` module to handle JWT
tokens.
This commit is contained in:
Fabio Manganiello 2022-11-21 12:30:38 +01:00
parent 02f89258b8
commit a2c8e27bd8
4 changed files with 45 additions and 51 deletions

View file

@ -1,17 +1,14 @@
import base64
import datetime import datetime
import hashlib import hashlib
import json
import random import random
import rsa
import time import time
from typing import Optional, Dict from typing import Optional, Dict
import bcrypt import bcrypt
try:
from jwt.exceptions import PyJWTError
from jwt import encode as jwt_encode, decode as jwt_decode
except ImportError:
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -197,17 +194,20 @@ class UserManager:
if not user: if not user:
raise InvalidCredentialsException() raise InvalidCredentialsException()
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair() pub_key, _ = get_or_generate_jwt_rsa_key_pair()
payload = { payload = json.dumps(
'username': username, {
'created_at': datetime.datetime.now().timestamp(), 'username': username,
'expires_at': expires_at.timestamp() if expires_at else None, 'created_at': datetime.datetime.now().timestamp(),
} 'expires_at': expires_at.timestamp() if expires_at else None,
},
sort_keys=True,
indent=None,
)
token = jwt_encode(payload, priv_key, algorithm='RS256') return base64.b64encode(
if isinstance(token, bytes): rsa.encrypt(payload.encode('ascii'), pub_key)
token = token.decode() ).decode()
return token
@staticmethod @staticmethod
def validate_jwt_token(token: str) -> Dict[str, str]: def validate_jwt_token(token: str) -> Dict[str, str]:
@ -227,12 +227,17 @@ class UserManager:
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token. :raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
""" """
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair() _, priv_key = get_or_generate_jwt_rsa_key_pair()
try: try:
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256']) payload = json.loads(
except PyJWTError as e: rsa.decrypt(
raise InvalidJWTTokenException(str(e)) base64.b64decode(token.encode('ascii')),
priv_key
).decode('ascii')
)
except (TypeError, ValueError) as e:
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
expires_at = payload.get('expires_at') expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at: if expires_at and time.time() > expires_at:

View file

@ -7,6 +7,7 @@ import logging
import os import os
import pathlib import pathlib
import re import re
import rsa
import signal import signal
import socket import socket
import ssl import ssl
@ -15,6 +16,7 @@ from typing import Optional, Tuple, Union
from dateutil import parser, tz from dateutil import parser, tz
from redis import Redis from redis import Redis
from rsa.key import PublicKey, PrivateKey
logger = logging.getLogger('utils') logger = logging.getLogger('utils')
@ -366,7 +368,8 @@ def run(action, *args, **kwargs):
return response.output return response.output
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]: def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) \
-> Tuple[PublicKey, PrivateKey]:
""" """
Generate an RSA key pair. Generate an RSA key pair.
@ -384,28 +387,11 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
:param size: Key size (default: 2048 bits). :param size: Key size (default: 2048 bits).
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``. :return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
""" """
from cryptography.hazmat.primitives import serialization logger.info('Generating RSA keypair')
from cryptography.hazmat.primitives.asymmetric import rsa pub_key, priv_key = rsa.newkeys(size)
from cryptography.hazmat.backends import default_backend logger.info('Generated RSA keypair')
public_key_str = pub_key.save_pkcs1('PEM').decode()
public_exp = 65537 private_key_str = priv_key.save_pkcs1('PEM').decode()
private_key = rsa.generate_private_key(
public_exponent=public_exp,
key_size=size,
backend=default_backend()
)
logger.info('Generating RSA {} key pair'.format(size))
private_key_str = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_key_str = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1,
).decode()
if key_file: if key_file:
logger.info('Saving private key to {}'.format(key_file)) logger.info('Saving private key to {}'.format(key_file))
@ -415,7 +401,7 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
f2.write(public_key_str) f2.write(public_key_str)
os.chmod(key_file, 0o600) os.chmod(key_file, 0o600)
return public_key_str, private_key_str return pub_key, priv_key
def get_or_generate_jwt_rsa_key_pair(): def get_or_generate_jwt_rsa_key_pair():
@ -426,9 +412,14 @@ def get_or_generate_jwt_rsa_key_pair():
pub_key_file = priv_key_file + '.pub' pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, \ with (
open(priv_key_file, 'r') as f2: open(pub_key_file, 'r') as f1,
return f1.read(), f2.read() open(priv_key_file, 'r') as f2
):
return (
rsa.PublicKey.load_pkcs1(f1.read().encode()),
rsa.PrivateKey.load_pkcs1(f2.read().encode()),
)
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048) return generate_rsa_key_pair(priv_key_file, size=2048)

View file

@ -14,8 +14,7 @@ frozendict
requests requests
sqlalchemy sqlalchemy
bcrypt bcrypt
cryptography rsa
pyjwt
zeroconf zeroconf
paho-mqtt paho-mqtt
websocket-client websocket-client

View file

@ -64,8 +64,7 @@ setup(
'zeroconf>=0.27.0', 'zeroconf>=0.27.0',
'tz', 'tz',
'python-dateutil', 'python-dateutil',
# 'cryptography', 'rsa',
'pyjwt',
'marshmallow', 'marshmallow',
'frozendict', 'frozendict',
'flask', 'flask',