Added support for JWT token-based authentication

This commit is contained in:
Fabio Manganiello 2021-02-12 22:43:34 +01:00
parent 06ca5be54b
commit b3c28f6773
9 changed files with 287 additions and 12 deletions

View file

@ -0,0 +1,69 @@
import datetime
import json
import logging
from typing import Dict, Optional
from flask import Blueprint, request, abort, jsonify
from platypush.exceptions.user import UserException
from platypush.user import UserManager
auth = Blueprint('auth', __name__)
log = logging.getLogger(__name__)
# Declare routes list
__routes__ = [
auth,
]
@auth.route('/auth', methods=['POST'])
def auth_endpoint() -> Dict[str, Optional[str]]:
"""
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
structure:
.. code-block:: json
{
"username": "USERNAME",
"password": "PASSWORD",
"expiry_days": "The generated token should be valid for these many days"
}
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
:return: Return structure:
.. code-block:: json
{
"token": "<generated token here>"
}
"""
try:
payload = json.loads(request.get_data(as_text=True))
username, password = payload['username'], payload['password']
except Exception as e:
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
abort(400)
return jsonify({'token': None})
expiry_days = payload.get('expiry_days')
expires_at = None
if expiry_days:
expires_at = datetime.datetime.now() + datetime.timedelta(days=expiry_days)
user_manager = UserManager()
try:
return jsonify({
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
})
except UserException as e:
abort(401, str(e))
return jsonify({'token': None})

View file

@ -109,6 +109,7 @@ def send_request(action, wait_for_response=True, **kwargs):
def _authenticate_token(): def _authenticate_token():
token = Config.get('token') token = Config.get('token')
user_manager = UserManager()
if 'X-Token' in request.headers: if 'X-Token' in request.headers:
user_token = request.headers['X-Token'] user_token = request.headers['X-Token']
@ -119,7 +120,11 @@ def _authenticate_token():
else: else:
return False return False
return token and user_token == token try:
user_manager.validate_jwt_token(user_token)
return True
except:
return token and user_token == token
def _authenticate_http(): def _authenticate_http():
@ -179,7 +184,6 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
user_manager = UserManager() user_manager = UserManager()
n_users = user_manager.get_user_count() n_users = user_manager.get_user_count()
token = Config.get('token')
skip_methods = skip_auth_methods or [] skip_methods = skip_auth_methods or []
# User/pass HTTP authentication # User/pass HTTP authentication
@ -191,7 +195,7 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
# Token-based authentication # Token-based authentication
token_auth_ok = True token_auth_ok = True
if token and 'token' not in skip_methods: if 'token' not in skip_methods:
token_auth_ok = _authenticate_token() token_auth_ok = _authenticate_token()
if token_auth_ok: if token_auth_ok:
return f(*args, **kwargs) return f(*args, **kwargs)

View file

@ -0,0 +1,20 @@
from typing import Optional, Union
class PlatypushException(Exception):
"""
Base class for all Platypush exceptions.
"""
def __init__(self, error: Optional[Union[str, Exception]] = None, *args):
super().__init__(*args)
self._inner_exception = None
self._msg = None
if isinstance(error, str):
self._msg = error
elif isinstance(error, Exception):
self._inner_exception = error
self._msg = str(error)
def __str__(self):
return self._msg

View file

@ -0,0 +1,44 @@
from typing import Optional, Union
from platypush.exceptions import PlatypushException
class UserException(PlatypushException):
"""
Base class for all user exceptions.
"""
def __init__(self, *args, user: Optional[Union[str, int]] = None, **kwargs):
super().__init__(*args, **kwargs)
self.user = user
class AuthenticationException(UserException):
"""
Authentication error exception.
"""
def __init__(self, error='Unauthorized', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidTokenException(AuthenticationException):
"""
Exception raised in case of wrong user token.
"""
def __init__(self, error='Invalid user token', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidCredentialsException(AuthenticationException):
"""
Exception raised in case of wrong user token.
"""
def __init__(self, error='Invalid credentials', *args, **kwargs):
super().__init__(error, *args, **kwargs)
class InvalidJWTTokenException(InvalidTokenException):
"""
Exception raised in case of wrong/invalid JWT token.
"""
def __init__(self, error='Invalid JWT token', *args, **kwargs):
super().__init__(error, *args, **kwargs)

View file

@ -56,7 +56,7 @@ class UserPlugin(Plugin):
:return: True if the provided username and password are correct, False otherwise :return: True if the provided username and password are correct, False otherwise
""" """
return self.user_manager.authenticate_user(username, password) return True if self.user_manager.authenticate_user(username, password) else False
@action @action
def update_password(self, username, old_password, new_password): def update_password(self, username, old_password, new_password):

View file

@ -1,14 +1,19 @@
import datetime import datetime
import hashlib import hashlib
import random import random
import time
from typing import Optional, Dict
import bcrypt import bcrypt
import jwt
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base() Base = declarative_base()
@ -121,7 +126,8 @@ class UserManager:
def create_user_session(self, username, password, expires_at=None): def create_user_session(self, username, password, expires_at=None):
session = self._get_db_session() session = self._get_db_session()
if not self._authenticate_user(session, username, password): user = self._authenticate_user(session, username, password)
if not user:
return None return None
if expires_at: if expires_at:
@ -130,9 +136,8 @@ class UserManager:
elif isinstance(expires_at, str): elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at) expires_at = datetime.datetime.fromisoformat(expires_at)
user = self._get_user(session, username) user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
user_session = UserSession(user_id=user.user_id, session_token=self._generate_token(), csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
csrf_token=self._generate_token(), created_at=datetime.datetime.utcnow(),
expires_at=expires_at) expires_at=expires_at)
session.add(user_session) session.add(user_session)
@ -152,10 +157,64 @@ class UserManager:
return bcrypt.checkpw(pwd.encode(), hashed_pwd) return bcrypt.checkpw(pwd.encode(), hashed_pwd)
@staticmethod @staticmethod
def _generate_token(): def generate_session_token():
rand = bytes(random.randint(0, 255) for _ in range(0, 255)) rand = bytes(random.randint(0, 255) for _ in range(0, 255))
return hashlib.sha256(rand).hexdigest() return hashlib.sha256(rand).hexdigest()
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
"""
Create a user JWT token for API usage.
:param username: User name.
:param password: Password.
:param expires_at: Expiration datetime of the token.
:return: The generated JWT token as a string.
:raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
"""
user = self.authenticate_user(username, password)
if not user:
raise InvalidCredentialsException()
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
payload = {
'username': username,
'created_at': datetime.datetime.now().timestamp(),
'expires_at': expires_at.timestamp() if expires_at else None,
}
return jwt.encode(payload, priv_key, algorithm='RS256').decode()
@staticmethod
def validate_jwt_token(token: str) -> Dict[str, str]:
"""
Validate a JWT token.
:param token: Token to validate.
:return: On success, it returns the JWT payload with the following structure:
.. code-block:: json
{
"username": "user ID/name",
"created_at": "token creation timestamp",
"expires_at": "token expiration timestamp"
}
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
"""
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
try:
payload = jwt.decode(token.encode(), pub_key, algorithms=['RS256'])
except jwt.exceptions.PyJWTError as e:
raise InvalidJWTTokenException(str(e))
expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at:
raise InvalidJWTTokenException('Expired JWT token')
return payload
def _get_db_session(self): def _get_db_session(self):
Base.metadata.create_all(self._engine) Base.metadata.create_all(self._engine)
session = scoped_session(sessionmaker()) session = scoped_session(sessionmaker())
@ -163,11 +222,17 @@ class UserManager:
return session() return session()
def _authenticate_user(self, session, username, password): def _authenticate_user(self, session, username, password):
"""
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
"""
user = self._get_user(session, username) user = self._get_user(session, username)
if not user: if not user:
return False return None
return self._check_password(password, user.password) if not self._check_password(password, user.password):
return None
return user
class User(Base): class User(Base):

View file

@ -4,12 +4,13 @@ import importlib
import inspect import inspect
import logging import logging
import os import os
import pathlib
import re import re
import signal import signal
import socket import socket
import ssl import ssl
import urllib.request import urllib.request
from typing import Optional from typing import Optional, Tuple
logger = logging.getLogger('utils') logger = logging.getLogger('utils')
@ -351,4 +352,72 @@ def run(action, *args, **kwargs):
return response.output return response.output
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
"""
Generate an RSA key pair.
:param key_file: Target file for the private key (the associated public key will be stored in ``<key_file>.pub``.
If no key file is specified then the public and private keys will be returned in ASCII format in a dictionary
with the following structure:
.. code-block:: json
{
"private": "private key here",
"public": "public key here"
}
:param size: Key size (default: 2048 bits).
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
"""
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
public_exp = 65537
private_key = rsa.generate_private_key(
public_exponent=public_exp,
key_size=size,
backend=default_backend()
)
logger.info('Generating RSA {} key pair'.format(size))
private_key_str = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
).decode()
public_key_str = private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1,
).decode()
if key_file:
logger.info('Saving private key to {}'.format(key_file))
with open(os.path.expanduser(key_file), 'w') as f1, \
open(os.path.expanduser(key_file) + '.pub', 'w') as f2:
f1.write(private_key_str)
f2.write(public_key_str)
os.chmod(key_file, 0o600)
return public_key_str, private_key_str
def get_or_generate_jwt_rsa_key_pair():
from platypush.config import Config
key_dir = os.path.join(Config.get('workdir'), 'jwt')
priv_key_file = os.path.join(key_dir, 'id_rsa')
pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
with open(pub_key_file, 'r') as f1, \
open(priv_key_file, 'r') as f2:
return f1.read(), f2.read()
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

View file

@ -34,6 +34,8 @@ sqlalchemy
# Support for multi-users and password authentication # Support for multi-users and password authentication
bcrypt bcrypt
cryptography
pyjwt
# Support for Zeroconf/Bonjour # Support for Zeroconf/Bonjour
zeroconf zeroconf

View file

@ -65,6 +65,8 @@ setup(
'zeroconf>=0.27.0', 'zeroconf>=0.27.0',
'tz', 'tz',
'python-dateutil', 'python-dateutil',
'cryptography',
'pyjwt',
], ],
extras_require={ extras_require={