forked from platypush/platypush
Removed PyJWT dependency.
PyJWT is a very brittle and cumbersome dependency that expects several cryptography libraries to be already installed on the system, and it can lead to hard-to-debug errors when ported to different systems. Moreover, it installs the whole `cryptography` package, which is several MBs in size, takes time to compile, and it requires a Rust compiler to be present on the target machine. Platypush will now use the Python-native `rsa` module to handle JWT tokens.
This commit is contained in:
parent
02f89258b8
commit
a2c8e27bd8
4 changed files with 45 additions and 51 deletions
|
@ -1,17 +1,14 @@
|
|||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import rsa
|
||||
import time
|
||||
from typing import Optional, Dict
|
||||
|
||||
import bcrypt
|
||||
|
||||
try:
|
||||
from jwt.exceptions import PyJWTError
|
||||
from jwt import encode as jwt_encode, decode as jwt_decode
|
||||
except ImportError:
|
||||
from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
@ -197,17 +194,20 @@ class UserManager:
|
|||
if not user:
|
||||
raise InvalidCredentialsException()
|
||||
|
||||
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||
payload = {
|
||||
'username': username,
|
||||
'created_at': datetime.datetime.now().timestamp(),
|
||||
'expires_at': expires_at.timestamp() if expires_at else None,
|
||||
}
|
||||
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
|
||||
payload = json.dumps(
|
||||
{
|
||||
'username': username,
|
||||
'created_at': datetime.datetime.now().timestamp(),
|
||||
'expires_at': expires_at.timestamp() if expires_at else None,
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=None,
|
||||
)
|
||||
|
||||
token = jwt_encode(payload, priv_key, algorithm='RS256')
|
||||
if isinstance(token, bytes):
|
||||
token = token.decode()
|
||||
return token
|
||||
return base64.b64encode(
|
||||
rsa.encrypt(payload.encode('ascii'), pub_key)
|
||||
).decode()
|
||||
|
||||
@staticmethod
|
||||
def validate_jwt_token(token: str) -> Dict[str, str]:
|
||||
|
@ -227,12 +227,17 @@ class UserManager:
|
|||
|
||||
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
|
||||
"""
|
||||
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||
_, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||
|
||||
try:
|
||||
payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256'])
|
||||
except PyJWTError as e:
|
||||
raise InvalidJWTTokenException(str(e))
|
||||
payload = json.loads(
|
||||
rsa.decrypt(
|
||||
base64.b64decode(token.encode('ascii')),
|
||||
priv_key
|
||||
).decode('ascii')
|
||||
)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}')
|
||||
|
||||
expires_at = payload.get('expires_at')
|
||||
if expires_at and time.time() > expires_at:
|
||||
|
|
|
@ -7,6 +7,7 @@ import logging
|
|||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import rsa
|
||||
import signal
|
||||
import socket
|
||||
import ssl
|
||||
|
@ -15,6 +16,7 @@ from typing import Optional, Tuple, Union
|
|||
|
||||
from dateutil import parser, tz
|
||||
from redis import Redis
|
||||
from rsa.key import PublicKey, PrivateKey
|
||||
|
||||
logger = logging.getLogger('utils')
|
||||
|
||||
|
@ -366,7 +368,8 @@ def run(action, *args, **kwargs):
|
|||
return response.output
|
||||
|
||||
|
||||
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
|
||||
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) \
|
||||
-> Tuple[PublicKey, PrivateKey]:
|
||||
"""
|
||||
Generate an RSA key pair.
|
||||
|
||||
|
@ -384,28 +387,11 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
|
|||
:param size: Key size (default: 2048 bits).
|
||||
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
|
||||
"""
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
public_exp = 65537
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=public_exp,
|
||||
key_size=size,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
logger.info('Generating RSA {} key pair'.format(size))
|
||||
private_key_str = private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode()
|
||||
|
||||
public_key_str = private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.PKCS1,
|
||||
).decode()
|
||||
logger.info('Generating RSA keypair')
|
||||
pub_key, priv_key = rsa.newkeys(size)
|
||||
logger.info('Generated RSA keypair')
|
||||
public_key_str = pub_key.save_pkcs1('PEM').decode()
|
||||
private_key_str = priv_key.save_pkcs1('PEM').decode()
|
||||
|
||||
if key_file:
|
||||
logger.info('Saving private key to {}'.format(key_file))
|
||||
|
@ -415,7 +401,7 @@ def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> T
|
|||
f2.write(public_key_str)
|
||||
os.chmod(key_file, 0o600)
|
||||
|
||||
return public_key_str, private_key_str
|
||||
return pub_key, priv_key
|
||||
|
||||
|
||||
def get_or_generate_jwt_rsa_key_pair():
|
||||
|
@ -426,9 +412,14 @@ def get_or_generate_jwt_rsa_key_pair():
|
|||
pub_key_file = priv_key_file + '.pub'
|
||||
|
||||
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
||||
with open(pub_key_file, 'r') as f1, \
|
||||
open(priv_key_file, 'r') as f2:
|
||||
return f1.read(), f2.read()
|
||||
with (
|
||||
open(pub_key_file, 'r') as f1,
|
||||
open(priv_key_file, 'r') as f2
|
||||
):
|
||||
return (
|
||||
rsa.PublicKey.load_pkcs1(f1.read().encode()),
|
||||
rsa.PrivateKey.load_pkcs1(f2.read().encode()),
|
||||
)
|
||||
|
||||
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||
return generate_rsa_key_pair(priv_key_file, size=2048)
|
||||
|
|
|
@ -14,8 +14,7 @@ frozendict
|
|||
requests
|
||||
sqlalchemy
|
||||
bcrypt
|
||||
cryptography
|
||||
pyjwt
|
||||
rsa
|
||||
zeroconf
|
||||
paho-mqtt
|
||||
websocket-client
|
||||
|
|
3
setup.py
3
setup.py
|
@ -64,8 +64,7 @@ setup(
|
|||
'zeroconf>=0.27.0',
|
||||
'tz',
|
||||
'python-dateutil',
|
||||
# 'cryptography',
|
||||
'pyjwt',
|
||||
'rsa',
|
||||
'marshmallow',
|
||||
'frozendict',
|
||||
'flask',
|
||||
|
|
Loading…
Add table
Reference in a new issue