forked from platypush/platypush
Added support for JWT token-based authentication
This commit is contained in:
parent
06ca5be54b
commit
b3c28f6773
9 changed files with 287 additions and 12 deletions
69
platypush/backend/http/app/routes/auth.py
Normal file
69
platypush/backend/http/app/routes/auth.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from flask import Blueprint, request, abort, jsonify
|
||||||
|
|
||||||
|
from platypush.exceptions.user import UserException
|
||||||
|
from platypush.user import UserManager
|
||||||
|
|
||||||
|
auth = Blueprint('auth', __name__)
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Declare routes list
|
||||||
|
__routes__ = [
|
||||||
|
auth,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@auth.route('/auth', methods=['POST'])
|
||||||
|
def auth_endpoint() -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
|
||||||
|
structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"username": "USERNAME",
|
||||||
|
"password": "PASSWORD",
|
||||||
|
"expiry_days": "The generated token should be valid for these many days"
|
||||||
|
}
|
||||||
|
|
||||||
|
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
|
||||||
|
|
||||||
|
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
|
||||||
|
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
|
||||||
|
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
|
||||||
|
|
||||||
|
:return: Return structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "<generated token here>"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = json.loads(request.get_data(as_text=True))
|
||||||
|
username, password = payload['username'], payload['password']
|
||||||
|
except Exception as e:
|
||||||
|
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
|
||||||
|
abort(400)
|
||||||
|
return jsonify({'token': None})
|
||||||
|
|
||||||
|
expiry_days = payload.get('expiry_days')
|
||||||
|
expires_at = None
|
||||||
|
if expiry_days:
|
||||||
|
expires_at = datetime.datetime.now() + datetime.timedelta(days=expiry_days)
|
||||||
|
|
||||||
|
user_manager = UserManager()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonify({
|
||||||
|
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
|
||||||
|
})
|
||||||
|
except UserException as e:
|
||||||
|
abort(401, str(e))
|
||||||
|
return jsonify({'token': None})
|
|
@ -109,6 +109,7 @@ def send_request(action, wait_for_response=True, **kwargs):
|
||||||
|
|
||||||
def _authenticate_token():
|
def _authenticate_token():
|
||||||
token = Config.get('token')
|
token = Config.get('token')
|
||||||
|
user_manager = UserManager()
|
||||||
|
|
||||||
if 'X-Token' in request.headers:
|
if 'X-Token' in request.headers:
|
||||||
user_token = request.headers['X-Token']
|
user_token = request.headers['X-Token']
|
||||||
|
@ -119,6 +120,10 @@ def _authenticate_token():
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_manager.validate_jwt_token(user_token)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
return token and user_token == token
|
return token and user_token == token
|
||||||
|
|
||||||
|
|
||||||
|
@ -179,7 +184,6 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
user_manager = UserManager()
|
user_manager = UserManager()
|
||||||
n_users = user_manager.get_user_count()
|
n_users = user_manager.get_user_count()
|
||||||
token = Config.get('token')
|
|
||||||
skip_methods = skip_auth_methods or []
|
skip_methods = skip_auth_methods or []
|
||||||
|
|
||||||
# User/pass HTTP authentication
|
# User/pass HTTP authentication
|
||||||
|
@ -191,7 +195,7 @@ def authenticate(redirect_page='', skip_auth_methods=None, check_csrf_token=Fals
|
||||||
|
|
||||||
# Token-based authentication
|
# Token-based authentication
|
||||||
token_auth_ok = True
|
token_auth_ok = True
|
||||||
if token and 'token' not in skip_methods:
|
if 'token' not in skip_methods:
|
||||||
token_auth_ok = _authenticate_token()
|
token_auth_ok = _authenticate_token()
|
||||||
if token_auth_ok:
|
if token_auth_ok:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
20
platypush/exceptions/__init__.py
Normal file
20
platypush/exceptions/__init__.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class PlatypushException(Exception):
|
||||||
|
"""
|
||||||
|
Base class for all Platypush exceptions.
|
||||||
|
"""
|
||||||
|
def __init__(self, error: Optional[Union[str, Exception]] = None, *args):
|
||||||
|
super().__init__(*args)
|
||||||
|
self._inner_exception = None
|
||||||
|
self._msg = None
|
||||||
|
|
||||||
|
if isinstance(error, str):
|
||||||
|
self._msg = error
|
||||||
|
elif isinstance(error, Exception):
|
||||||
|
self._inner_exception = error
|
||||||
|
self._msg = str(error)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self._msg
|
44
platypush/exceptions/user.py
Normal file
44
platypush/exceptions/user.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from platypush.exceptions import PlatypushException
|
||||||
|
|
||||||
|
|
||||||
|
class UserException(PlatypushException):
|
||||||
|
"""
|
||||||
|
Base class for all user exceptions.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, user: Optional[Union[str, int]] = None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationException(UserException):
|
||||||
|
"""
|
||||||
|
Authentication error exception.
|
||||||
|
"""
|
||||||
|
def __init__(self, error='Unauthorized', *args, **kwargs):
|
||||||
|
super().__init__(error, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTokenException(AuthenticationException):
|
||||||
|
"""
|
||||||
|
Exception raised in case of wrong user token.
|
||||||
|
"""
|
||||||
|
def __init__(self, error='Invalid user token', *args, **kwargs):
|
||||||
|
super().__init__(error, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCredentialsException(AuthenticationException):
|
||||||
|
"""
|
||||||
|
Exception raised in case of wrong user token.
|
||||||
|
"""
|
||||||
|
def __init__(self, error='Invalid credentials', *args, **kwargs):
|
||||||
|
super().__init__(error, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidJWTTokenException(InvalidTokenException):
|
||||||
|
"""
|
||||||
|
Exception raised in case of wrong/invalid JWT token.
|
||||||
|
"""
|
||||||
|
def __init__(self, error='Invalid JWT token', *args, **kwargs):
|
||||||
|
super().__init__(error, *args, **kwargs)
|
|
@ -56,7 +56,7 @@ class UserPlugin(Plugin):
|
||||||
:return: True if the provided username and password are correct, False otherwise
|
:return: True if the provided username and password are correct, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.user_manager.authenticate_user(username, password)
|
return True if self.user_manager.authenticate_user(username, password) else False
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def update_password(self, username, old_password, new_password):
|
def update_password(self, username, old_password, new_password):
|
||||||
|
|
|
@ -1,14 +1,19 @@
|
||||||
import datetime
|
import datetime
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Optional, Dict
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
import jwt
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
|
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
|
||||||
|
from platypush.utils import get_or_generate_jwt_rsa_key_pair
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
@ -121,7 +126,8 @@ class UserManager:
|
||||||
|
|
||||||
def create_user_session(self, username, password, expires_at=None):
|
def create_user_session(self, username, password, expires_at=None):
|
||||||
session = self._get_db_session()
|
session = self._get_db_session()
|
||||||
if not self._authenticate_user(session, username, password):
|
user = self._authenticate_user(session, username, password)
|
||||||
|
if not user:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if expires_at:
|
if expires_at:
|
||||||
|
@ -130,9 +136,8 @@ class UserManager:
|
||||||
elif isinstance(expires_at, str):
|
elif isinstance(expires_at, str):
|
||||||
expires_at = datetime.datetime.fromisoformat(expires_at)
|
expires_at = datetime.datetime.fromisoformat(expires_at)
|
||||||
|
|
||||||
user = self._get_user(session, username)
|
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
|
||||||
user_session = UserSession(user_id=user.user_id, session_token=self._generate_token(),
|
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
|
||||||
csrf_token=self._generate_token(), created_at=datetime.datetime.utcnow(),
|
|
||||||
expires_at=expires_at)
|
expires_at=expires_at)
|
||||||
|
|
||||||
session.add(user_session)
|
session.add(user_session)
|
||||||
|
@ -152,10 +157,64 @@ class UserManager:
|
||||||
return bcrypt.checkpw(pwd.encode(), hashed_pwd)
|
return bcrypt.checkpw(pwd.encode(), hashed_pwd)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_token():
|
def generate_session_token():
|
||||||
rand = bytes(random.randint(0, 255) for _ in range(0, 255))
|
rand = bytes(random.randint(0, 255) for _ in range(0, 255))
|
||||||
return hashlib.sha256(rand).hexdigest()
|
return hashlib.sha256(rand).hexdigest()
|
||||||
|
|
||||||
|
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
|
||||||
|
"""
|
||||||
|
Create a user JWT token for API usage.
|
||||||
|
|
||||||
|
:param username: User name.
|
||||||
|
:param password: Password.
|
||||||
|
:param expires_at: Expiration datetime of the token.
|
||||||
|
:return: The generated JWT token as a string.
|
||||||
|
:raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
|
||||||
|
"""
|
||||||
|
user = self.authenticate_user(username, password)
|
||||||
|
if not user:
|
||||||
|
raise InvalidCredentialsException()
|
||||||
|
|
||||||
|
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||||
|
payload = {
|
||||||
|
'username': username,
|
||||||
|
'created_at': datetime.datetime.now().timestamp(),
|
||||||
|
'expires_at': expires_at.timestamp() if expires_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(payload, priv_key, algorithm='RS256').decode()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_jwt_token(token: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Validate a JWT token.
|
||||||
|
|
||||||
|
:param token: Token to validate.
|
||||||
|
:return: On success, it returns the JWT payload with the following structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"username": "user ID/name",
|
||||||
|
"created_at": "token creation timestamp",
|
||||||
|
"expires_at": "token expiration timestamp"
|
||||||
|
}
|
||||||
|
|
||||||
|
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
|
||||||
|
"""
|
||||||
|
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token.encode(), pub_key, algorithms=['RS256'])
|
||||||
|
except jwt.exceptions.PyJWTError as e:
|
||||||
|
raise InvalidJWTTokenException(str(e))
|
||||||
|
|
||||||
|
expires_at = payload.get('expires_at')
|
||||||
|
if expires_at and time.time() > expires_at:
|
||||||
|
raise InvalidJWTTokenException('Expired JWT token')
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
def _get_db_session(self):
|
def _get_db_session(self):
|
||||||
Base.metadata.create_all(self._engine)
|
Base.metadata.create_all(self._engine)
|
||||||
session = scoped_session(sessionmaker())
|
session = scoped_session(sessionmaker())
|
||||||
|
@ -163,11 +222,17 @@ class UserManager:
|
||||||
return session()
|
return session()
|
||||||
|
|
||||||
def _authenticate_user(self, session, username, password):
|
def _authenticate_user(self, session, username, password):
|
||||||
|
"""
|
||||||
|
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
|
||||||
|
"""
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
if not user:
|
if not user:
|
||||||
return False
|
return None
|
||||||
|
|
||||||
return self._check_password(password, user.password)
|
if not self._check_password(password, user.password):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
|
|
|
@ -4,12 +4,13 @@ import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
logger = logging.getLogger('utils')
|
logger = logging.getLogger('utils')
|
||||||
|
|
||||||
|
@ -351,4 +352,72 @@ def run(action, *args, **kwargs):
|
||||||
return response.output
|
return response.output
|
||||||
|
|
||||||
|
|
||||||
|
def generate_rsa_key_pair(key_file: Optional[str] = None, size: int = 2048) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Generate an RSA key pair.
|
||||||
|
|
||||||
|
:param key_file: Target file for the private key (the associated public key will be stored in ``<key_file>.pub``.
|
||||||
|
If no key file is specified then the public and private keys will be returned in ASCII format in a dictionary
|
||||||
|
with the following structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"private": "private key here",
|
||||||
|
"public": "public key here"
|
||||||
|
}
|
||||||
|
|
||||||
|
:param size: Key size (default: 2048 bits).
|
||||||
|
:return: A tuple with the generated ``(priv_key_str, pub_key_str)``.
|
||||||
|
"""
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
|
||||||
|
public_exp = 65537
|
||||||
|
private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=public_exp,
|
||||||
|
key_size=size,
|
||||||
|
backend=default_backend()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info('Generating RSA {} key pair'.format(size))
|
||||||
|
private_key_str = private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||||
|
encryption_algorithm=serialization.NoEncryption()
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
public_key_str = private_key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PublicFormat.PKCS1,
|
||||||
|
).decode()
|
||||||
|
|
||||||
|
if key_file:
|
||||||
|
logger.info('Saving private key to {}'.format(key_file))
|
||||||
|
with open(os.path.expanduser(key_file), 'w') as f1, \
|
||||||
|
open(os.path.expanduser(key_file) + '.pub', 'w') as f2:
|
||||||
|
f1.write(private_key_str)
|
||||||
|
f2.write(public_key_str)
|
||||||
|
os.chmod(key_file, 0o600)
|
||||||
|
|
||||||
|
return public_key_str, private_key_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_generate_jwt_rsa_key_pair():
|
||||||
|
from platypush.config import Config
|
||||||
|
|
||||||
|
key_dir = os.path.join(Config.get('workdir'), 'jwt')
|
||||||
|
priv_key_file = os.path.join(key_dir, 'id_rsa')
|
||||||
|
pub_key_file = priv_key_file + '.pub'
|
||||||
|
|
||||||
|
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
||||||
|
with open(pub_key_file, 'r') as f1, \
|
||||||
|
open(priv_key_file, 'r') as f2:
|
||||||
|
return f1.read(), f2.read()
|
||||||
|
|
||||||
|
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||||
|
return generate_rsa_key_pair(priv_key_file, size=2048)
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -34,6 +34,8 @@ sqlalchemy
|
||||||
|
|
||||||
# Support for multi-users and password authentication
|
# Support for multi-users and password authentication
|
||||||
bcrypt
|
bcrypt
|
||||||
|
cryptography
|
||||||
|
pyjwt
|
||||||
|
|
||||||
# Support for Zeroconf/Bonjour
|
# Support for Zeroconf/Bonjour
|
||||||
zeroconf
|
zeroconf
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -65,6 +65,8 @@ setup(
|
||||||
'zeroconf>=0.27.0',
|
'zeroconf>=0.27.0',
|
||||||
'tz',
|
'tz',
|
||||||
'python-dateutil',
|
'python-dateutil',
|
||||||
|
'cryptography',
|
||||||
|
'pyjwt',
|
||||||
],
|
],
|
||||||
|
|
||||||
extras_require={
|
extras_require={
|
||||||
|
|
Loading…
Reference in a new issue