Added support for JWT token-based authentication

This commit is contained in:
Fabio Manganiello 2021-02-12 22:43:34 +01:00
parent 06ca5be54b
commit b3c28f6773
9 changed files with 287 additions and 12 deletions

View file

@ -0,0 +1,69 @@
import datetime
import json
import logging
from typing import Dict, Optional
from flask import Blueprint, request, abort, jsonify
from platypush.exceptions.user import UserException
from platypush.user import UserManager
auth = Blueprint('auth', __name__)
log = logging.getLogger(__name__)
# Declare routes list
__routes__ = [
auth,
]
@auth.route('/auth', methods=['POST'])
def auth_endpoint() -> Dict[str, Optional[str]]:
"""
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
structure:
.. code-block:: json
{
"username": "USERNAME",
"password": "PASSWORD",
"expiry_days": "The generated token should be valid for these many days"
}
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
:return: Return structure:
.. code-block:: json
{
"token": "<generated token here>"
}
"""
try:
payload = json.loads(request.get_data(as_text=True))
username, password = payload['username'], payload['password']
except Exception as e:
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
abort(400)
return jsonify({'token': None})
expiry_days = payload.get('expiry_days')
expires_at = None
if expiry_days:
expires_at = datetime.datetime.now() + datetime.timedelta(days=expiry_days)
user_manager = UserManager()
try:
return jsonify({
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
})
except UserException as e:
abort(401, str(e))
return jsonify({'token': None})

View file

@ -109,6 +109,7 @@ def send_request(action, wait_for_response=True, **kwargs):
def _authenticate_token():
token = Config.get('token')
user_manager = UserManager()
if 'X-Token' in request.headers:
user_token = request.headers['X-Token']
@ -119,6 +120,10 @@ def _authenticate_token():
else:
return False
try:
user_manager.validate_jwt_token(user_token)
return True
except:
return token and user_token == token
@ -179,7 +184,6 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
def wrapper(*args, **kwargs):
user_manager = UserManager()
n_users = user_manager.get_user_count()
token = Config.get('token')
skip_methods = skip_auth_methods or []
# User/pass HTTP authentication
@ -191,7 +195,7 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
# Token-based authentication
token_auth_ok = True
if token and 'token' not in skip_methods:
if 'token' not in skip_methods:
token_auth_ok = _authenticate_token()
if token_auth_ok:
return f(*args, **kwargs)

View file

@ -0,0 +1,20 @@
from typing import Optional, Union
class PlatypushException(Exception):
"""
Base class for all Platypush exceptions.
"""
def __init__(self, error: Optional[Union[str, Exception]] = None, *args):
super().__init__(*args)
self._inner_exception = None
self._msg = None
if isinstance(error, str):
self._msg = error
elif isinstance(error, Exception):
self._inner_exception = error
self._msg = str(error)
def __str__(self):
return self._msg

View file

@ -0,0 +1,44 @@
from typing import Optional, Union
from platypush.exceptions import PlatypushException
class UserException(PlatypushException):
"""
Base class for all user exceptions.
"""
def __init__(self, *args, user: Optional[Union[str, int]] = None, **kwargs):
super().__init__(*args, **kwargs)
self.user = user
class AuthenticationException(UserException):
"""
Authentication error exception.
"""
def __init__(self, error='Unauthorized', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidTokenException(AuthenticationException):
"""
Exception raised in case of wrong user token.
"""
def __init__(self, error='Invalid user token', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidCredentialsException(AuthenticationException):
"""
Exception raised in case of wrong user token.
"""
def __init__(self, error='Invalid credentials', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidJWTTokenException(InvalidTokenException):
"""
Exception raised in case of wrong/invalid JWT token.
"""
def __init__(self, error='Invalid JWT token', *args, **kwargs):
super().__init__(error, *args, **kwargs)

View file

@ -56,7 +56,7 @@ class UserPlugin(Plugin):
:return: True if the provided username and password are correct, False otherwise
"""
return self.user_manager.authenticate_user(username, password)
return True if self.user_manager.authenticate_user(username, password) else False
@action
def update_password(self, username, old_password, new_password):

View file

@ -1,14 +1,19 @@
import datetime
import hashlib
import random
import time
from typing import Optional, Dict
import bcrypt
import jwt
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base()
@ -121,7 +126,8 @@ class UserManager:
def create_user_session(self, username, password, expires_at=None):
session = self._get_db_session()
if not self._authenticate_user(session, username, password):
user = self._authenticate_user(session, username, password)
if not user:
return None
if expires_at:
@ -130,9 +136,8 @@ class UserManager:
elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at)
user = self._get_user(session, username)
user_session = UserSession(user_id=user.user_id, session_token=self._generate_token(),
csrf_token=self._generate_token(), created_at=datetime.datetime.utcnow(),
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
expires_at=expires_at)
session.add(user_session)
@ -152,10 +157,64 @@ class UserManager:
return bcrypt.checkpw(pwd.encode(), hashed_pwd)
@staticmethod
def _generate_token():
def generate_session_token():
rand = bytes(random.randint(0, 255) for _ in range(0, 255))
return hashlib.sha256(rand).hexdigest()
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
"""
Create a user JWT token for API usage.
:param username: User name.
:param password: Password.
:param expires_at: Expiration datetime of the token.
:return: The generated JWT token as a string.
:raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
"""
user = self.authenticate_user(username, password)
if not user:
raise InvalidCredentialsException()
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
payload = {
'username': username,
'created_at': datetime.datetime.now().timestamp(),
'expires_at': expires_at.timestamp() if expires_at else None,
}
return jwt.encode(payload, priv_key, algorithm='RS256').decode()
@staticmethod
def validate_jwt_token(token: str) -> Dict[str, str]:
"""
Validate a JWT token.
:param token: Token to validate.
:return: On success, it returns the JWT payload with the following structure:
.. code-block:: json
{
"username": "user ID/name",
"created_at": "token creation timestamp",
"expires_at": "token expiration timestamp"
}
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
"""
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
try:
payload = jwt.decode(token.encode(), pub_key, algorithms=['RS256'])
except jwt.exceptions.PyJWTError as e:
raise InvalidJWTTokenException(str(e))
expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at:
raise InvalidJWTTokenException('Expired JWT token')
return payload
def _get_db_session(self):
Base.metadata.create_all(self._engine)
session = scoped_session(sessionmaker())
@ -163,11 +222,17 @@ class UserManager:
return session()
def _authenticate_user(self, session, username, password):
"""
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
"""
user = self._get_user(session, username)
if not user:
return False
return None
return self._check_password(password, user.password)
if not self._check_password(password, user.password):
return None
return user
class User(Base):

View file

@ -4,12 +4,13 @@ import importlib
import inspect
import logging
import os
import pathlib
import re
import signal
import socket
import ssl
import urllib.request
from typing import Optional
from typing import Optional, Tuple
logger = logging.getLogger('utils')
@ -351,4 +352,72 @@ def run(action, *args, **kwargs):
return response.output
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
"""
Generate an RSA key pair.
:param key_file: Target file for the private key (the associated public key will be stored in ``<key_file>.pub``.
If no key file is specified then the public and private keys will be returned in ASCII format in a dictionary
with the following structure:
.. code-block:: json
{
"private": "private key here",
"public": "public key here"
}
:param size: Key size (default: 2048 bits).
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
"""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
public_exp = 65537
private_key = rsa.generate_private_key(
public_exponent=public_exp,
key_size=size,
backend=default_backend()
)
logger.info('Generating RSA {} key pair'.format(size))
private_key_str = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_key_str = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1,
).decode()
if key_file:
logger.info('Saving private key to {}'.format(key_file))
with open(os.path.expanduser(key_file), 'w') as f1, \
open(os.path.expanduser(key_file) + '.pub', 'w') as f2:
f1.write(private_key_str)
f2.write(public_key_str)
os.chmod(key_file, 0o600)
return public_key_str, private_key_str
def get_or_generate_jwt_rsa_key_pair():
from platypush.config import Config
key_dir = os.path.join(Config.get('workdir'), 'jwt')
priv_key_file = os.path.join(key_dir, 'id_rsa')
pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, \
open(priv_key_file, 'r') as f2:
return f1.read(), f2.read()
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048)
# vim:sw=4:ts=4:et:

View file

@ -34,6 +34,8 @@ sqlalchemy
# Support for multi-users and password authentication
bcrypt
cryptography
pyjwt
# Support for Zeroconf/Bonjour
zeroconf

View file

@ -65,6 +65,8 @@ setup(
'zeroconf>=0.27.0',
'tz',
'python-dateutil',
'cryptography',
'pyjwt',
],
extras_require={