diff --git a/platypush/backend/http/app/routes/auth.py b/platypush/backend/http/app/routes/auth.py index cfdc30e7e0..015c6c4157 100644 --- a/platypush/backend/http/app/routes/auth.py +++ b/platypush/backend/http/app/routes/auth.py @@ -4,8 +4,10 @@ import logging from flask import Blueprint, request, abort, jsonify +from platypush.backend.http.app.utils.auth import UserAuthStatus from platypush.exceptions.user import UserException from platypush.user import UserManager +from platypush.utils import utcnow auth = Blueprint('auth', __name__) log = logging.getLogger(__name__) @@ -16,39 +18,24 @@ __routes__ = [ ] -@auth.route('/auth', methods=['POST']) -def auth_endpoint(): - """ - Authentication endpoint. It validates the user credentials provided over a JSON payload with the following - structure: +def _dump_session(session, redirect_page='/'): + return jsonify( + { + 'status': 'ok', + 'user_id': session.user_id, + 'session_token': session.session_token, + 'expires_at': session.expires_at, + 'redirect': redirect_page, + } + ) - .. code-block:: json - { - "username": "USERNAME", - "password": "PASSWORD", - "expiry_days": "The generated token should be valid for these many days" - } - - ``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely. - - Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it - will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the - ``Authorization: Bearer `` header upon HTTP calls. - - :return: Return structure: - - .. code-block:: json - - { - "token": "" - } - """ +def _jwt_auth(): try: payload = json.loads(request.get_data(as_text=True)) username, password = payload['username'], payload['password'] - except Exception as e: - log.warning('Invalid payload passed to the auth endpoint: ' + str(e)) + except Exception: + log.warning('Invalid payload passed to the auth endpoint') abort(400) expiry_days = payload.get('expiry_days') @@ -59,8 +46,174 @@ def auth_endpoint(): user_manager = UserManager() try: - return jsonify({ - 'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at), - }) + return jsonify( + { + 'token': user_manager.generate_jwt_token( + username=username, password=password, expires_at=expires_at + ), + } + ) except UserException as e: abort(401, str(e)) + + +def _session_auth(): + user_manager = UserManager() + session_token = request.cookies.get('session_token') + redirect_page = request.args.get('redirect') or '/' + + if session_token: + user, session = user_manager.authenticate_user_session(session_token) # type: ignore + if user and session: + return _dump_session(session, redirect_page) + + if request.form: + username = request.form.get('username') + password = request.form.get('password') + code = request.form.get('code') + remember = request.form.get('remember') + expires = utcnow() + datetime.timedelta(days=365) if remember else None + session, status = user_manager.create_user_session( # type: ignore + username=username, + password=password, + code=code, + expires_at=expires, + error_on_invalid_credentials=True, + ) + + if session: + return _dump_session(session, redirect_page) + + if status: + return status.to_response() # type: ignore + + return UserAuthStatus.INVALID_CREDENTIALS.to_response() + + +def _register_route(): + """Registration endpoint""" + user_manager = UserManager() + session_token = request.cookies.get('session_token') + redirect_page = request.args.get('redirect') or '/' + + if session_token: + user, session = user_manager.authenticate_user_session(session_token) # type: ignore + if user and session: + return _dump_session(session, redirect_page) + + if user_manager.get_user_count() > 0: + return UserAuthStatus.REGISTRATION_DISABLED.to_response() + + if not request.form: + return UserAuthStatus.MISSING_USERNAME.to_response() + + username = request.form.get('username') + password = request.form.get('password') + confirm_password = request.form.get('confirm_password') + remember = request.form.get('remember') + + if not username: + return UserAuthStatus.MISSING_USERNAME.to_response() + if not password: + return UserAuthStatus.MISSING_PASSWORD.to_response() + if password != confirm_password: + return UserAuthStatus.PASSWORD_MISMATCH.to_response() + + user_manager.create_user(username=username, password=password) + session, status = user_manager.create_user_session( # type: ignore + username=username, + password=password, + expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None), + error_on_invalid_credentials=True, + ) + + if session: + return _dump_session(session, redirect_page) + + if status: + return status.to_response() # type: ignore + + return UserAuthStatus.INVALID_CREDENTIALS.to_response() + + +def _auth_get(): + """ + Get the current authentication status of the user session. + """ + user_manager = UserManager() + session_token = request.cookies.get('session_token') + redirect_page = request.args.get('redirect') or '/' + user, session, status = user_manager.authenticate_user_session( # type: ignore + session_token, with_error=True + ) + + if user and session: + return _dump_session(session, redirect_page) + + if status: + return UserAuthStatus.by_status(status).to_response() # type: ignore + + return UserAuthStatus.INVALID_CREDENTIALS.to_response() + + +def _auth_post(): + """ + Authenticate the user session. + """ + auth_type = request.args.get('type') or 'jwt' + + if auth_type == 'jwt': + return _jwt_auth() + + if auth_type == 'register': + return _register_route() + + if auth_type == 'login': + return _session_auth() + + return UserAuthStatus.INVALID_AUTH_TYPE.to_response() + + +@auth.route('/auth', methods=['GET', 'POST']) +def auth_endpoint(): + """ + Authentication endpoint. It validates the user credentials provided over a + JSON payload with the following structure: + + .. code-block:: json + + { + "username": "USERNAME", + "password": "PASSWORD", + "code": "2FA_CODE", + "expiry_days": "The generated token should be valid for these many days" + } + + ``expiry_days`` is optional, and if omitted or set to zero the token will + be valid indefinitely. + + Upon successful validation, a new JWT token will be generated using the + service's self-generated RSA key-pair and it will be returned to the user. + The token can then be used to authenticate API calls to ``/execute`` by + setting the ``Authorization: Bearer `` header upon HTTP calls. + + :return: Return structure: + + .. code-block:: json + + { + "token": "" + } + """ + if request.method == 'GET': + return _auth_get() + + if request.method == 'POST': + return _auth_post() + + return UserAuthStatus.INVALID_METHOD.to_response() + + +# from flask import Blueprint, request, redirect, render_template, make_response, abort + +# vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/routes/login.py b/platypush/backend/http/app/routes/login.py deleted file mode 100644 index 94f19b2eba..0000000000 --- a/platypush/backend/http/app/routes/login.py +++ /dev/null @@ -1,56 +0,0 @@ -import datetime -import re - -from flask import Blueprint, request, redirect, render_template, make_response - -from platypush.backend.http.app import template_folder -from platypush.backend.http.utils import HttpUtils -from platypush.user import UserManager -from platypush.utils import utcnow - -login = Blueprint('login', __name__, template_folder=template_folder) - -# Declare routes list -__routes__ = [ - login, -] - - -@login.route('/login', methods=['GET', 'POST']) -def login(): - """Login page""" - user_manager = UserManager() - session_token = request.cookies.get('session_token') - - redirect_page = request.args.get('redirect') - if not redirect_page: - redirect_page = request.headers.get('Referer', '/') - if re.search('(^https?://[^/]+)?/login[^?#]?', redirect_page): - # Prevent redirect loop - redirect_page = '/' - - if session_token: - user, session = user_manager.authenticate_user_session(session_token) - if user: - return redirect(redirect_page, 302) # lgtm [py/url-redirection] - - if request.form: - username = request.form.get('username') - password = request.form.get('password') - remember = request.form.get('remember') - expires = utcnow() + datetime.timedelta(days=365) if remember else None - - session = user_manager.create_user_session( - username=username, password=password, expires_at=expires - ) - - if session: - redirect_target = redirect(redirect_page, 302) # lgtm [py/url-redirection] - response = make_response(redirect_target) - response.set_cookie('session_token', session.session_token, expires=expires) - return response - - return render_template('index.html', utils=HttpUtils) - - -# vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/routes/register.py b/platypush/backend/http/app/routes/register.py deleted file mode 100644 index 0150d8be53..0000000000 --- a/platypush/backend/http/app/routes/register.py +++ /dev/null @@ -1,71 +0,0 @@ -import datetime -import re - -from flask import Blueprint, request, redirect, render_template, make_response, abort - -from platypush.backend.http.app import template_folder -from platypush.backend.http.utils import HttpUtils -from platypush.user import UserManager -from platypush.utils import utcnow - -register = Blueprint('register', __name__, template_folder=template_folder) - -# Declare routes list -__routes__ = [ - register, -] - - -@register.route('/register', methods=['GET', 'POST']) -def register(): - """Registration page""" - user_manager = UserManager() - redirect_page = request.args.get('redirect') - if not redirect_page: - redirect_page = request.headers.get('Referer', '/') - if re.search('(^https?://[^/]+)?/register[^?#]?', redirect_page): - # Prevent redirect loop - redirect_page = '/' - - session_token = request.cookies.get('session_token') - - if session_token: - user, session = user_manager.authenticate_user_session(session_token) - if user: - return redirect(redirect_page, 302) # lgtm [py/url-redirection] - - if user_manager.get_user_count() > 0: - return redirect( - '/login?redirect=' + redirect_page, 302 - ) # lgtm [py/url-redirection] - - if request.form: - username = request.form.get('username') - password = request.form.get('password') - confirm_password = request.form.get('confirm_password') - remember = request.form.get('remember') - - if password == confirm_password: - user_manager.create_user(username=username, password=password) - session = user_manager.create_user_session( - username=username, - password=password, - expires_at=( - utcnow() + datetime.timedelta(days=1) if not remember else None - ), - ) - - if session: - redirect_target = redirect( - redirect_page, 302 - ) # lgtm [py/url-redirection] - response = make_response(redirect_target) - response.set_cookie('session_token', session.session_token) - return response - else: - abort(400, 'Password mismatch') - - return render_template('index.html', utils=HttpUtils) - - -# vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/streaming/_base.py b/platypush/backend/http/app/streaming/_base.py index 862550324f..962b54b4e5 100644 --- a/platypush/backend/http/app/streaming/_base.py +++ b/platypush/backend/http/app/streaming/_base.py @@ -7,7 +7,7 @@ from typing import Optional from tornado.web import RequestHandler, stream_request_body from platypush.backend.http.app.utils import logger -from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status +from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status from ..mixins import PubSubMixin @@ -29,7 +29,7 @@ class StreamingRoute(RequestHandler, PubSubMixin, ABC): """ if self.auth_required: auth_status = get_auth_status(self.request) - if auth_status != AuthStatus.OK: + if auth_status != UserAuthStatus.OK: self.send_error(auth_status.value.code, error=auth_status.value.message) return diff --git a/platypush/backend/http/app/utils/auth/__init__.py b/platypush/backend/http/app/utils/auth/__init__.py index 7d1b5ef65c..6a4231676d 100644 --- a/platypush/backend/http/app/utils/auth/__init__.py +++ b/platypush/backend/http/app/utils/auth/__init__.py @@ -2,14 +2,14 @@ import base64 from functools import wraps from typing import Optional -from flask import request, redirect, jsonify +from flask import request, redirect from flask.wrappers import Response from platypush.config import Config from platypush.user import UserManager from ..logger import logger -from .status import AuthStatus +from .status import UserAuthStatus user_manager = UserManager() @@ -128,18 +128,18 @@ def authenticate( skip_auth_methods=skip_auth_methods, ) - if auth_status == AuthStatus.OK: + if auth_status == UserAuthStatus.OK: return f(*args, **kwargs) if json: - return jsonify(auth_status.to_dict()), auth_status.value.code + return auth_status.to_response() - if auth_status == AuthStatus.NO_USERS: + if auth_status == UserAuthStatus.REGISTRATION_REQUIRED: return redirect( f'/register?redirect={redirect_page or request.url}', 307 ) - if auth_status == AuthStatus.UNAUTHORIZED: + if auth_status == UserAuthStatus.INVALID_CREDENTIALS: return redirect(f'/login?redirect={redirect_page or request.url}', 307) return Response( @@ -154,7 +154,7 @@ def authenticate( # pylint: disable=too-many-return-statements -def get_auth_status(req, skip_auth_methods=None) -> AuthStatus: +def get_auth_status(req, skip_auth_methods=None) -> UserAuthStatus: """ Check against the available authentication methods (except those listed in ``skip_auth_methods``) if the user is properly authenticated. @@ -168,29 +168,33 @@ def get_auth_status(req, skip_auth_methods=None) -> AuthStatus: if n_users > 0 and 'http' not in skip_methods: http_auth_ok = authenticate_user_pass(req) if http_auth_ok: - return AuthStatus.OK + return UserAuthStatus.OK # Token-based authentication token_auth_ok = True if 'token' not in skip_methods: token_auth_ok = authenticate_token(req) if token_auth_ok: - return AuthStatus.OK + return UserAuthStatus.OK # Session token based authentication session_auth_ok = True if n_users > 0 and 'session' not in skip_methods: - return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED + return ( + UserAuthStatus.OK + if authenticate_session(req) + else UserAuthStatus.INVALID_CREDENTIALS + ) # At least a user should be created before accessing an authenticated resource if n_users == 0 and 'session' not in skip_methods: - return AuthStatus.NO_USERS + return UserAuthStatus.REGISTRATION_REQUIRED if ( # pylint: disable=too-many-boolean-expressions ('http' not in skip_methods and http_auth_ok) or ('token' not in skip_methods and token_auth_ok) or ('session' not in skip_methods and session_auth_ok) ): - return AuthStatus.OK + return UserAuthStatus.OK - return AuthStatus.UNAUTHORIZED + return UserAuthStatus.INVALID_CREDENTIALS diff --git a/platypush/backend/http/app/utils/auth/status.py b/platypush/backend/http/app/utils/auth/status.py index ee64fb51d9..5c3d0d75a6 100644 --- a/platypush/backend/http/app/utils/auth/status.py +++ b/platypush/backend/http/app/utils/auth/status.py @@ -1,21 +1,67 @@ from collections import namedtuple from enum import Enum +from flask import jsonify -StatusValue = namedtuple('StatusValue', ['code', 'message']) +from platypush.user import AuthenticationStatus + +StatusValue = namedtuple('StatusValue', ['code', 'error', 'message']) -class AuthStatus(Enum): +class UserAuthStatus(Enum): """ Models the status of the authentication. """ - OK = StatusValue(200, 'OK') - UNAUTHORIZED = StatusValue(401, 'Unauthorized') - NO_USERS = StatusValue(412, 'Please create a user first') + OK = StatusValue(200, AuthenticationStatus.OK, 'OK') + INVALID_AUTH_TYPE = StatusValue( + 400, AuthenticationStatus.INVALID_AUTH_TYPE, 'Invalid authentication type' + ) + INVALID_CREDENTIALS = StatusValue( + 401, AuthenticationStatus.INVALID_CREDENTIALS, 'Invalid credentials' + ) + INVALID_JWT_TOKEN = StatusValue( + 401, AuthenticationStatus.INVALID_JWT_TOKEN, 'Invalid JWT token' + ) + INVALID_OTP_CODE = StatusValue( + 401, AuthenticationStatus.INVALID_OTP_CODE, 'Invalid OTP code' + ) + INVALID_METHOD = StatusValue( + 405, AuthenticationStatus.INVALID_METHOD, 'Invalid method' + ) + MISSING_OTP_CODE = StatusValue( + 401, AuthenticationStatus.MISSING_OTP_CODE, 'Missing OTP code' + ) + MISSING_PASSWORD = StatusValue( + 400, AuthenticationStatus.MISSING_PASSWORD, 'Missing password' + ) + MISSING_USERNAME = StatusValue( + 400, AuthenticationStatus.MISSING_USERNAME, 'Missing username' + ) + PASSWORD_MISMATCH = StatusValue( + 400, AuthenticationStatus.PASSWORD_MISMATCH, 'Password mismatch' + ) + REGISTRATION_DISABLED = StatusValue( + 401, AuthenticationStatus.REGISTRATION_DISABLED, 'Registrations are disabled' + ) + REGISTRATION_REQUIRED = StatusValue( + 412, AuthenticationStatus.REGISTRATION_REQUIRED, 'Please create a user first' + ) def to_dict(self): return { 'code': self.value[0], - 'message': self.value[1], + 'error': self.value[1].name, + 'message': self.value[2], } + + def to_response(self): + return jsonify(self.to_dict()), self.value[0] + + @staticmethod + def by_status(status: AuthenticationStatus): + for auth_status in UserAuthStatus: + if auth_status.value[1] == status: + return auth_status + + return None diff --git a/platypush/backend/http/app/ws/_base.py b/platypush/backend/http/app/ws/_base.py index 3f2ba3d08b..4cdfee4efe 100644 --- a/platypush/backend/http/app/ws/_base.py +++ b/platypush/backend/http/app/ws/_base.py @@ -5,7 +5,7 @@ from threading import Thread from tornado.ioloop import IOLoop from tornado.websocket import WebSocketHandler -from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status +from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status from ..mixins import MessageType, PubSubMixin @@ -25,7 +25,7 @@ class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC): def open(self, *_, **__): auth_status = get_auth_status(self.request) - if auth_status != AuthStatus.OK: + if auth_status != UserAuthStatus.OK: self.close(code=1008, reason=auth_status.value.message) # Policy Violation return diff --git a/platypush/backend/http/webapp/src/App.vue b/platypush/backend/http/webapp/src/App.vue index e344c906be..004394f17f 100644 --- a/platypush/backend/http/webapp/src/App.vue +++ b/platypush/backend/http/webapp/src/App.vue @@ -1,20 +1,30 @@ @@ -116,7 +210,7 @@ form { width: 100%; } - input[type=submit], + [type=submit], input[type=password] { border-radius: 1em; } @@ -133,10 +227,33 @@ form { .buttons { text-align: center; - input[type=submit] { + [type=submit] { + position: relative; + width: 6em; + height: 2.5em; padding: .5em .75em; + display: inline-flex; + align-items: center; + justify-content: center; + + &.loading { + background: none; + border: none; + cursor: not-allowed; + } } } + + .auth-error { + background: $error-bg; + display: flex; + margin: 1em 0 -2em 0; + padding: .5em; + align-items: center; + justify-content: center; + border: $notification-error-border; + border-radius: 1em; + } } a { diff --git a/platypush/backend/http/webapp/vue.config.js b/platypush/backend/http/webapp/vue.config.js index 1ff73b8cdf..a55d56d608 100644 --- a/platypush/backend/http/webapp/vue.config.js +++ b/platypush/backend/http/webapp/vue.config.js @@ -34,16 +34,17 @@ module.exports = { devServer: { proxy: { + '^/auth': httpProxy, + '^/camera/': httpProxy, + '^/execute': httpProxy, + '^/file': httpProxy, + '^/logo.svg': httpProxy, + '^/logout': httpProxy, + '^/media/': httpProxy, + '^/sound/': httpProxy, '^/ws/events': wsProxy, '^/ws/requests': wsProxy, '^/ws/shell': wsProxy, - '^/execute': httpProxy, - '^/file': httpProxy, - '^/auth': httpProxy, - '^/camera/': httpProxy, - '^/sound/': httpProxy, - '^/media/': httpProxy, - '^/logo.svg': httpProxy, } } }; diff --git a/platypush/plugins/user/__init__.py b/platypush/plugins/user/__init__.py index 59ed283731..8d29afe046 100644 --- a/platypush/plugins/user/__init__.py +++ b/platypush/plugins/user/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional, Tuple, Union from platypush.plugins import Plugin, action from platypush.user import UserManager @@ -66,15 +66,46 @@ class UserPlugin(Plugin): } @action - def authenticate_user(self, username, password): + def authenticate_user( + self, + username: str, + password: str, + code: Optional[str] = None, + return_details: bool = False, + ) -> Union[bool, Tuple[bool, str]]: """ Authenticate a user. - :return: True if the provided username and password are correct, False - otherwise. - """ + :param username: Username. + :param password: Password. + :param code: Optional 2FA code, if 2FA is enabled for the user. + :param return_error_details: If True then return the error details in + case of authentication failure. + :return: If ``return_details`` is False (default), the action returns + True if the provided credentials are valid, False otherwise. + If ``return_details`` is True then the action returns a tuple + (authenticated, error_details) where ``authenticated`` is True if + the provided credentials are valid, False otherwise, and + ``error_details`` is a string containing the error details in case + of authentication failure. Supported error details are: - return bool(self.user_manager.authenticate_user(username, password)) + - ``invalid_credentials``: Invalid username or password. + - ``invalid_otp_code``: Invalid 2FA code. + - ``missing_otp_code``: Username/password are correct, but a 2FA + code is required for the user. + + """ + response = self.user_manager.authenticate_user( + username, password, code=code, return_error=return_details + ) + + if return_details: + assert ( + isinstance(response, tuple) and len(response) == 2 + ), 'Invalid response from authenticate_user' + return response[0], response[1].value + + return response @action def update_password(self, username, old_password, new_password): @@ -111,7 +142,7 @@ class UserPlugin(Plugin): return None, "No such user: {}".format(username) @action - def create_session(self, username, password, expires_at=None): + def create_session(self, username, password, code=None, expires_at=None): """ Create a user session. @@ -130,7 +161,7 @@ class UserPlugin(Plugin): """ session = self.user_manager.create_user_session( - username=username, password=password, expires_at=expires_at + username=username, password=password, code=code, expires_at=expires_at ) if not session: @@ -140,9 +171,9 @@ class UserPlugin(Plugin): 'session_token': session.session_token, 'user_id': session.user_id, 'created_at': session.created_at.isoformat(), - 'expires_at': session.expires_at.isoformat() - if session.expires_at - else None, + 'expires_at': ( + session.expires_at.isoformat() if session.expires_at else None # type: ignore + ), } @action diff --git a/platypush/user/__init__.py b/platypush/user/__init__.py index a6db904bd3..66f5455007 100644 --- a/platypush/user/__init__.py +++ b/platypush/user/__init__.py @@ -1,11 +1,12 @@ import base64 import datetime +import enum import hashlib import json import os import random import time -from typing import Optional, Dict +from typing import List, Optional, Dict, Tuple, Union import rsa @@ -13,12 +14,32 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey from sqlalchemy.orm import make_transient from platypush.common.db import Base +from platypush.config import Config from platypush.context import get_plugin from platypush.exceptions.user import ( InvalidJWTTokenException, InvalidCredentialsException, ) -from platypush.utils import get_or_generate_jwt_rsa_key_pair, utcnow +from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow + + +class AuthenticationStatus(enum.Enum): + """ + Enum for authentication errors. + """ + + OK = '' + INVALID_AUTH_TYPE = 'invalid_auth_type' + INVALID_CREDENTIALS = 'invalid_credentials' + INVALID_METHOD = 'invalid_method' + INVALID_JWT_TOKEN = 'invalid_jwt_token' + INVALID_OTP_CODE = 'invalid_otp_code' + MISSING_OTP_CODE = 'missing_otp_code' + MISSING_PASSWORD = 'missing_password' + MISSING_USERNAME = 'missing_username' + PASSWORD_MISMATCH = 'password_mismatch' + REGISTRATION_DISABLED = 'registration_disabled' + REGISTRATION_REQUIRED = 'registration_required' class UserManager: @@ -26,6 +47,14 @@ class UserManager: Main class for managing platform users """ + _otp_workdir = os.path.join(Config.get_workdir(), 'otp') + _otp_keyfile = os.path.join(_otp_workdir, 'key') + _otp_keyfile_pub = f'{_otp_keyfile}.pub' + + _jwt_workdir = os.path.join(Config.get_workdir(), 'jwt') + _jwt_keyfile = os.path.join(_jwt_workdir, 'id_rsa') + _jwt_keyfile_pub = f'{_jwt_keyfile}.pub' + def __init__(self): db_plugin = get_plugin('db') assert db_plugin, 'Database plugin not configured' @@ -41,6 +70,44 @@ class UserManager: def _get_session(self, *args, **kwargs): return self.db.get_session(self.db.get_engine(), *args, **kwargs) + @classmethod + def _get_jwt_rsa_key_pair(cls): + """ + Get or generate the JWT RSA key pair. + """ + return get_or_generate_stored_rsa_key_pair(cls._jwt_keyfile, size=2048) + + @classmethod + def _get_or_generate_otp_rsa_key_pair(cls): + """ + Get or generate the OTP RSA key pair. + """ + return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=4096) + + @staticmethod + def _encrypt(data: Union[str, bytes, dict, list, tuple], key: rsa.PublicKey) -> str: + """ + Encrypt the data using the given RSA public key. + """ + if isinstance(data, tuple): + data = list(data) + if isinstance(data, (dict, list)): + data = json.dumps(data, sort_keys=True, indent=None) + if isinstance(data, str): + data = data.encode('ascii') + + return base64.b64encode(rsa.encrypt(data, key)).decode() + + @staticmethod + def _decrypt(data: Union[str, bytes], key: rsa.PrivateKey) -> str: + """ + Decrypt the data using the given RSA private key. + """ + if isinstance(data, str): + data = data.encode('ascii') + + return rsa.decrypt(base64.b64decode(data), key).decode() + def get_user(self, username): with self._get_session() as session: user = self._get_user(session, username) @@ -88,9 +155,9 @@ class UserManager: return self._mask_password(user) - def update_password(self, username, old_password, new_password): + def update_password(self, username, old_password, new_password, code=None): with self._get_session(locked=True) as session: - if not self._authenticate_user(session, username, old_password): + if not self._authenticate_user(session, username, old_password, code=code): return False user = self._get_user(session, username) @@ -103,12 +170,22 @@ class UserManager: session.commit() return True - def authenticate_user(self, username, password): + def authenticate_user(self, username, password, code=None, return_error=False): with self._get_session() as session: - return self._authenticate_user(session, username, password) + return self._authenticate_user( + session, username, password, code=code, return_error=return_error + ) - def authenticate_user_session(self, session_token): + def authenticate_user_session(self, session_token, with_error=False): with self._get_session() as session: + users_count = session.query(User).count() + if not users_count: + return ( + (None, None, AuthenticationStatus.REGISTRATION_REQUIRED) + if with_error + else (None, None) + ) + user_session = ( session.query(UserSession) .filter_by(session_token=session_token) @@ -122,10 +199,21 @@ class UserManager: ) if not user_session or (expires_at and expires_at < utcnow()): - return None, None + return ( + (None, None, AuthenticationStatus.INVALID_CREDENTIALS) + if with_error + else (None, None) + ) user = session.query(User).filter_by(user_id=user_session.user_id).first() - return self._mask_password(user), user_session + return ( + (self._mask_password(user), user_session, AuthenticationStatus.OK) + if with_error + else ( + self._mask_password(user), + user_session, + ) + ) def delete_user(self, username): with self._get_session(locked=True) as session: @@ -158,11 +246,25 @@ class UserManager: session.commit() return True - def create_user_session(self, username, password, expires_at=None): + def create_user_session( + self, + username, + password, + code=None, + expires_at=None, + error_on_invalid_credentials=False, + ): with self._get_session(locked=True) as session: - user = self._authenticate_user(session, username, password) + user, status = self._authenticate_user( # type: ignore + session, + username, + password, + code=code, + return_error=error_on_invalid_credentials, + ) + if not user: - return None + return None if not error_on_invalid_credentials else (None, status) if expires_at: if isinstance(expires_at, (int, float)): @@ -180,7 +282,53 @@ class UserManager: session.add(user_session) session.commit() - return user_session + return user_session, ( + AuthenticationStatus.OK if not error_on_invalid_credentials else status + ) + + def create_otp_secret( + self, username: str, expires_at: Optional[datetime.datetime] = None + ): + pubkey, _ = self._get_or_generate_otp_rsa_key_pair() + + # Generate a new OTP secret and encrypt it with the OTP RSA key pair + otp_secret = "".join( + random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(32) + ) + + encrypted_secret = self._encrypt(otp_secret, pubkey) + + with self._get_session(locked=True) as session: + user = self._get_user(session, username) + assert user, f'No such user: {username}' + + # Create a new OTP secret + user_otp = UserOtp( + user_id=user.user_id, + otp_secret=encrypted_secret, + created_at=utcnow(), + expires_at=expires_at, + ) + + # Remove any existing OTP secret and replace it with the new one + session.query(UserOtp).filter_by(user_id=user.user_id).delete() + session.add(user_otp) + session.commit() + + return user_otp + + def get_otp_secret(self, username: str) -> Optional[str]: + with self._get_session() as session: + user = self._get_user(session, username) + if not user: + return None + + user_otp = session.query(UserOtp).filter_by(user_id=user.user_id).first() + if not user_otp: + return None + + _, priv_key = self._get_or_generate_otp_rsa_key_pair() + return self._decrypt(user_otp.otp_secret, priv_key) @staticmethod def _get_user(session, username): @@ -268,20 +416,17 @@ class UserManager: if not user: raise InvalidCredentialsException() - pub_key, _ = get_or_generate_jwt_rsa_key_pair() - payload = json.dumps( + pub_key, _ = self._get_jwt_rsa_key_pair() + return self._encrypt( { 'username': username, 'password': password, 'created_at': datetime.datetime.now().timestamp(), 'expires_at': expires_at.timestamp() if expires_at else None, }, - sort_keys=True, - indent=None, + pub_key, ) - return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode() - def validate_jwt_token(self, token: str) -> Dict[str, str]: """ Validate a JWT token. @@ -299,14 +444,10 @@ class UserManager: :raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token. """ - _, priv_key = get_or_generate_jwt_rsa_key_pair() + _, priv_key = self._get_jwt_rsa_key_pair() try: - payload = json.loads( - rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode( - 'ascii' - ) - ) + payload = json.loads(self._decrypt(token, priv_key)) except (TypeError, ValueError) as e: raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e @@ -323,23 +464,160 @@ class UserManager: return payload - def _authenticate_user(self, session, username, password): + def _authenticate_user( + self, + session, + username: str, + password: str, + code: Optional[str] = None, + return_error: bool = False, + ) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]: """ - :return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise. + :return: :class:`platypush.user.User` instance if the user exists and + the password is valid, ``None`` otherwise. """ user = self._get_user(session, username) - if not user: - return None + # The user does not exist + if not user: + return ( + None + if not return_error + else (None, AuthenticationStatus.INVALID_CREDENTIALS) + ) + + # The password is not correct if not self._check_password( password, user.password, bytes.fromhex(user.password_salt) if user.password_salt else None, user.hmac_iterations, ): - return None + return ( + None + if not return_error + else (None, AuthenticationStatus.INVALID_CREDENTIALS) + ) - return user + otp_secret = self.get_otp_secret(username) + + # The user doesn't have 2FA enabled and the password is correct: + # authentication successful + if not otp_secret: + return user if not return_error else (user, AuthenticationStatus.OK) + + # The user has 2FA enabled but the code is missing + if not code: + return ( + None + if not return_error + else (None, AuthenticationStatus.MISSING_OTP_CODE) + ) + + if self.validate_otp_code(username, code): + return user if not return_error else (user, AuthenticationStatus.OK) + + if not self.validate_backup_code(username, code): + return ( + None + if not return_error + else (None, AuthenticationStatus.INVALID_OTP_CODE) + ) + + return user if not return_error else (user, AuthenticationStatus.OK) + + def refresh_user_backup_codes(self, username: str): + """ + Refresh the backup codes for a user with 2FA enabled. + """ + with self._get_session(locked=True) as session: + user = self._get_user(session, username) + if not user: + return False + + session.query(UserBackupCode).filter_by(user_id=user.user_id).delete() + pub_key, _ = self._get_or_generate_otp_rsa_key_pair() + + for _ in range(10): + backup_code = "".join( + random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16) + ) + + user_backup_code = UserBackupCode( + user_id=user.user_id, + code=self._encrypt(backup_code, pub_key), + created_at=utcnow(), + expires_at=utcnow() + datetime.timedelta(days=30), + ) + + session.add(user_backup_code) + + session.commit() + return True + + def get_user_backup_codes(self, username: str) -> List['UserBackupCode']: + with self._get_session() as session: + user = self._get_user(session, username) + if not user: + return [] + + _, priv_key = self._get_or_generate_otp_rsa_key_pair() + codes = session.query(UserBackupCode).filter_by(user_id=user.user_id).all() + + for code in codes: + code.code = self._decrypt(code.code, priv_key) + + return codes + + def validate_backup_code(self, username: str, code: str) -> bool: + with self._get_session() as session: + user = self._get_user(session, username) + if not user: + return False + + pub_key, _ = self._get_or_generate_otp_rsa_key_pair() + user_backup_code = ( + session.query(UserBackupCode) + .filter_by(user_id=user.user_id, code=self._encrypt(code, pub_key)) + .first() + ) + + if not user_backup_code: + return False + + session.delete(user_backup_code) + session.commit() + + return True + + def validate_otp_code(self, username: str, code: str) -> bool: + otp = get_plugin('otp') + assert otp + + with self._get_session() as session: + user = self._get_user(session, username) + if not user: + return False + + otp_secret = self.get_otp_secret(username) + if not otp_secret: + return False + + _, priv_key = self._get_or_generate_otp_rsa_key_pair() + otp_secret = self._decrypt(otp_secret, priv_key) + + return otp.verify_time_otp(otp=code, secret=otp_secret) + + def disable_mfa(self, username: str): + with self._get_session(locked=True) as session: + user = self._get_user(session, username) + if not user: + return False + + session.query(UserOtp).filter_by(user_id=user.user_id).delete() + session.query(UserBackupCode).filter_by(user_id=user.user_id).delete() + session.commit() + return True class User(Base): @@ -370,4 +648,31 @@ class UserSession(Base): expires_at = Column(DateTime) +class UserOtp(Base): + """ + Models the UserOtp table, which contains the OTP secrets for each user. + """ + + __tablename__ = 'user_otp' + + user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True) + otp_secret = Column(String, nullable=False, unique=True) + created_at = Column(DateTime) + expires_at = Column(DateTime) + + +class UserBackupCode(Base): + """ + Models the UserBackupCode table, which contains the backup codes for each + user with 2FA enabled. + """ + + __tablename__ = 'user_backup_code' + + user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True) + code = Column(String, nullable=False, unique=True) + created_at = Column(DateTime) + expires_at = Column(DateTime) + + # vim:sw=4:ts=4:et: diff --git a/platypush/utils/__init__.py b/platypush/utils/__init__.py index 5f869e3d74..96c4e79811 100644 --- a/platypush/utils/__init__.py +++ b/platypush/utils/__init__.py @@ -532,7 +532,12 @@ def generate_rsa_key_pair( private_key_str = priv_key.save_pkcs1('PEM').decode() if key_file: + pathlib.Path(os.path.dirname(os.path.expanduser(key_file))).mkdir( + parents=True, exist_ok=True + ) + logger.info('Saving private key to %s', key_file) + with open(os.path.expanduser(key_file), 'w') as f1, open( os.path.expanduser(key_file) + '.pub', 'w' ) as f2: @@ -543,14 +548,20 @@ def generate_rsa_key_pair( return pub_key, priv_key -def get_or_generate_jwt_rsa_key_pair(): +def get_or_generate_stored_rsa_key_pair( + keyfile: str, size: int = 2048 +) -> Tuple[PublicKey, PrivateKey]: """ - Get or generate a JWT RSA key pair. - """ - from platypush.config import Config + Get or generate an RSA key pair and store it in the given key file. - key_dir = os.path.join(Config.get_workdir(), 'jwt') - priv_key_file = os.path.join(key_dir, 'id_rsa') + The private key will be stored in the given file, while the public key will + be stored in ``.pub``. + + :param keyfile: Path to the key file. + :param size: Key size in bits (default: 2048). + """ + keydir = os.path.dirname(os.path.expanduser(keyfile)) + priv_key_file = os.path.join(keydir, os.path.basename(keyfile)) pub_key_file = priv_key_file + '.pub' if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file): @@ -560,8 +571,15 @@ def get_or_generate_jwt_rsa_key_pair(): PrivateKey.load_pkcs1(f2.read().encode()), ) - pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755) - return generate_rsa_key_pair(priv_key_file, size=2048) + pub_key, priv_key = generate_rsa_key_pair(priv_key_file, size=size) + pathlib.Path(keydir).mkdir(parents=True, exist_ok=True, mode=0o755) + + with open(pub_key_file, 'w') as f1, open(priv_key_file, 'w') as f2: + f1.write(pub_key.save_pkcs1('PEM').decode()) + f2.write(priv_key.save_pkcs1('PEM').decode()) + os.chmod(priv_key_file, 0o600) + + return pub_key, priv_key def get_enabled_plugins() -> dict: