Compare commits
2 commits
fe2497577d
...
ee27b2c4c6
Author | SHA1 | Date | |
---|---|---|---|
ee27b2c4c6 | |||
8904e40f9f |
15 changed files with 843 additions and 266 deletions
|
@ -4,8 +4,10 @@ import logging
|
||||||
|
|
||||||
from flask import Blueprint, request, abort, jsonify
|
from flask import Blueprint, request, abort, jsonify
|
||||||
|
|
||||||
|
from platypush.backend.http.app.utils.auth import UserAuthStatus
|
||||||
from platypush.exceptions.user import UserException
|
from platypush.exceptions.user import UserException
|
||||||
from platypush.user import UserManager
|
from platypush.user import UserManager
|
||||||
|
from platypush.utils import utcnow
|
||||||
|
|
||||||
auth = Blueprint('auth', __name__)
|
auth = Blueprint('auth', __name__)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -16,39 +18,24 @@ __routes__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@auth.route('/auth', methods=['POST'])
|
def _dump_session(session, redirect_page='/'):
|
||||||
def auth_endpoint():
|
return jsonify(
|
||||||
"""
|
|
||||||
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
|
|
||||||
structure:
|
|
||||||
|
|
||||||
.. code-block:: json
|
|
||||||
|
|
||||||
{
|
{
|
||||||
"username": "USERNAME",
|
'status': 'ok',
|
||||||
"password": "PASSWORD",
|
'user_id': session.user_id,
|
||||||
"expiry_days": "The generated token should be valid for these many days"
|
'session_token': session.session_token,
|
||||||
|
'expires_at': session.expires_at,
|
||||||
|
'redirect': redirect_page,
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
|
|
||||||
|
|
||||||
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
|
def _jwt_auth():
|
||||||
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
|
|
||||||
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
|
|
||||||
|
|
||||||
:return: Return structure:
|
|
||||||
|
|
||||||
.. code-block:: json
|
|
||||||
|
|
||||||
{
|
|
||||||
"token": "<generated token here>"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(request.get_data(as_text=True))
|
payload = json.loads(request.get_data(as_text=True))
|
||||||
username, password = payload['username'], payload['password']
|
username, password = payload['username'], payload['password']
|
||||||
except Exception as e:
|
except Exception:
|
||||||
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
|
log.warning('Invalid payload passed to the auth endpoint')
|
||||||
abort(400)
|
abort(400)
|
||||||
|
|
||||||
expiry_days = payload.get('expiry_days')
|
expiry_days = payload.get('expiry_days')
|
||||||
|
@ -59,8 +46,174 @@ def auth_endpoint():
|
||||||
user_manager = UserManager()
|
user_manager = UserManager()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
|
{
|
||||||
})
|
'token': user_manager.generate_jwt_token(
|
||||||
|
username=username, password=password, expires_at=expires_at
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
except UserException as e:
|
except UserException as e:
|
||||||
abort(401, str(e))
|
abort(401, str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def _session_auth():
|
||||||
|
user_manager = UserManager()
|
||||||
|
session_token = request.cookies.get('session_token')
|
||||||
|
redirect_page = request.args.get('redirect') or '/'
|
||||||
|
|
||||||
|
if session_token:
|
||||||
|
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||||
|
if user and session:
|
||||||
|
return _dump_session(session, redirect_page)
|
||||||
|
|
||||||
|
if request.form:
|
||||||
|
username = request.form.get('username')
|
||||||
|
password = request.form.get('password')
|
||||||
|
code = request.form.get('code')
|
||||||
|
remember = request.form.get('remember')
|
||||||
|
expires = utcnow() + datetime.timedelta(days=365) if remember else None
|
||||||
|
session, status = user_manager.create_user_session( # type: ignore
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
code=code,
|
||||||
|
expires_at=expires,
|
||||||
|
error_on_invalid_credentials=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session:
|
||||||
|
return _dump_session(session, redirect_page)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
return status.to_response() # type: ignore
|
||||||
|
|
||||||
|
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
def _register_route():
|
||||||
|
"""Registration endpoint"""
|
||||||
|
user_manager = UserManager()
|
||||||
|
session_token = request.cookies.get('session_token')
|
||||||
|
redirect_page = request.args.get('redirect') or '/'
|
||||||
|
|
||||||
|
if session_token:
|
||||||
|
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||||
|
if user and session:
|
||||||
|
return _dump_session(session, redirect_page)
|
||||||
|
|
||||||
|
if user_manager.get_user_count() > 0:
|
||||||
|
return UserAuthStatus.REGISTRATION_DISABLED.to_response()
|
||||||
|
|
||||||
|
if not request.form:
|
||||||
|
return UserAuthStatus.MISSING_USERNAME.to_response()
|
||||||
|
|
||||||
|
username = request.form.get('username')
|
||||||
|
password = request.form.get('password')
|
||||||
|
confirm_password = request.form.get('confirm_password')
|
||||||
|
remember = request.form.get('remember')
|
||||||
|
|
||||||
|
if not username:
|
||||||
|
return UserAuthStatus.MISSING_USERNAME.to_response()
|
||||||
|
if not password:
|
||||||
|
return UserAuthStatus.MISSING_PASSWORD.to_response()
|
||||||
|
if password != confirm_password:
|
||||||
|
return UserAuthStatus.PASSWORD_MISMATCH.to_response()
|
||||||
|
|
||||||
|
user_manager.create_user(username=username, password=password)
|
||||||
|
session, status = user_manager.create_user_session( # type: ignore
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None),
|
||||||
|
error_on_invalid_credentials=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if session:
|
||||||
|
return _dump_session(session, redirect_page)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
return status.to_response() # type: ignore
|
||||||
|
|
||||||
|
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_get():
|
||||||
|
"""
|
||||||
|
Get the current authentication status of the user session.
|
||||||
|
"""
|
||||||
|
user_manager = UserManager()
|
||||||
|
session_token = request.cookies.get('session_token')
|
||||||
|
redirect_page = request.args.get('redirect') or '/'
|
||||||
|
user, session, status = user_manager.authenticate_user_session( # type: ignore
|
||||||
|
session_token, with_error=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if user and session:
|
||||||
|
return _dump_session(session, redirect_page)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
return UserAuthStatus.by_status(status).to_response() # type: ignore
|
||||||
|
|
||||||
|
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_post():
|
||||||
|
"""
|
||||||
|
Authenticate the user session.
|
||||||
|
"""
|
||||||
|
auth_type = request.args.get('type') or 'jwt'
|
||||||
|
|
||||||
|
if auth_type == 'jwt':
|
||||||
|
return _jwt_auth()
|
||||||
|
|
||||||
|
if auth_type == 'register':
|
||||||
|
return _register_route()
|
||||||
|
|
||||||
|
if auth_type == 'login':
|
||||||
|
return _session_auth()
|
||||||
|
|
||||||
|
return UserAuthStatus.INVALID_AUTH_TYPE.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
@auth.route('/auth', methods=['GET', 'POST'])
|
||||||
|
def auth_endpoint():
|
||||||
|
"""
|
||||||
|
Authentication endpoint. It validates the user credentials provided over a
|
||||||
|
JSON payload with the following structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"username": "USERNAME",
|
||||||
|
"password": "PASSWORD",
|
||||||
|
"code": "2FA_CODE",
|
||||||
|
"expiry_days": "The generated token should be valid for these many days"
|
||||||
|
}
|
||||||
|
|
||||||
|
``expiry_days`` is optional, and if omitted or set to zero the token will
|
||||||
|
be valid indefinitely.
|
||||||
|
|
||||||
|
Upon successful validation, a new JWT token will be generated using the
|
||||||
|
service's self-generated RSA key-pair and it will be returned to the user.
|
||||||
|
The token can then be used to authenticate API calls to ``/execute`` by
|
||||||
|
setting the ``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
|
||||||
|
|
||||||
|
:return: Return structure:
|
||||||
|
|
||||||
|
.. code-block:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"token": "<generated token here>"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if request.method == 'GET':
|
||||||
|
return _auth_get()
|
||||||
|
|
||||||
|
if request.method == 'POST':
|
||||||
|
return _auth_post()
|
||||||
|
|
||||||
|
return UserAuthStatus.INVALID_METHOD.to_response()
|
||||||
|
|
||||||
|
|
||||||
|
# from flask import Blueprint, request, redirect, render_template, make_response, abort
|
||||||
|
|
||||||
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -1,56 +0,0 @@
|
||||||
import datetime
|
|
||||||
import re
|
|
||||||
|
|
||||||
from flask import Blueprint, request, redirect, render_template, make_response
|
|
||||||
|
|
||||||
from platypush.backend.http.app import template_folder
|
|
||||||
from platypush.backend.http.utils import HttpUtils
|
|
||||||
from platypush.user import UserManager
|
|
||||||
from platypush.utils import utcnow
|
|
||||||
|
|
||||||
login = Blueprint('login', __name__, template_folder=template_folder)
|
|
||||||
|
|
||||||
# Declare routes list
|
|
||||||
__routes__ = [
|
|
||||||
login,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@login.route('/login', methods=['GET', 'POST'])
|
|
||||||
def login():
|
|
||||||
"""Login page"""
|
|
||||||
user_manager = UserManager()
|
|
||||||
session_token = request.cookies.get('session_token')
|
|
||||||
|
|
||||||
redirect_page = request.args.get('redirect')
|
|
||||||
if not redirect_page:
|
|
||||||
redirect_page = request.headers.get('Referer', '/')
|
|
||||||
if re.search('(^https?://[^/]+)?/login[^?#]?', redirect_page):
|
|
||||||
# Prevent redirect loop
|
|
||||||
redirect_page = '/'
|
|
||||||
|
|
||||||
if session_token:
|
|
||||||
user, session = user_manager.authenticate_user_session(session_token)
|
|
||||||
if user:
|
|
||||||
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
|
||||||
|
|
||||||
if request.form:
|
|
||||||
username = request.form.get('username')
|
|
||||||
password = request.form.get('password')
|
|
||||||
remember = request.form.get('remember')
|
|
||||||
expires = utcnow() + datetime.timedelta(days=365) if remember else None
|
|
||||||
|
|
||||||
session = user_manager.create_user_session(
|
|
||||||
username=username, password=password, expires_at=expires
|
|
||||||
)
|
|
||||||
|
|
||||||
if session:
|
|
||||||
redirect_target = redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
|
||||||
response = make_response(redirect_target)
|
|
||||||
response.set_cookie('session_token', session.session_token, expires=expires)
|
|
||||||
return response
|
|
||||||
|
|
||||||
return render_template('index.html', utils=HttpUtils)
|
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
|
|
@ -1,71 +0,0 @@
|
||||||
import datetime
|
|
||||||
import re
|
|
||||||
|
|
||||||
from flask import Blueprint, request, redirect, render_template, make_response, abort
|
|
||||||
|
|
||||||
from platypush.backend.http.app import template_folder
|
|
||||||
from platypush.backend.http.utils import HttpUtils
|
|
||||||
from platypush.user import UserManager
|
|
||||||
from platypush.utils import utcnow
|
|
||||||
|
|
||||||
register = Blueprint('register', __name__, template_folder=template_folder)
|
|
||||||
|
|
||||||
# Declare routes list
|
|
||||||
__routes__ = [
|
|
||||||
register,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@register.route('/register', methods=['GET', 'POST'])
|
|
||||||
def register():
|
|
||||||
"""Registration page"""
|
|
||||||
user_manager = UserManager()
|
|
||||||
redirect_page = request.args.get('redirect')
|
|
||||||
if not redirect_page:
|
|
||||||
redirect_page = request.headers.get('Referer', '/')
|
|
||||||
if re.search('(^https?://[^/]+)?/register[^?#]?', redirect_page):
|
|
||||||
# Prevent redirect loop
|
|
||||||
redirect_page = '/'
|
|
||||||
|
|
||||||
session_token = request.cookies.get('session_token')
|
|
||||||
|
|
||||||
if session_token:
|
|
||||||
user, session = user_manager.authenticate_user_session(session_token)
|
|
||||||
if user:
|
|
||||||
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
|
||||||
|
|
||||||
if user_manager.get_user_count() > 0:
|
|
||||||
return redirect(
|
|
||||||
'/login?redirect=' + redirect_page, 302
|
|
||||||
) # lgtm [py/url-redirection]
|
|
||||||
|
|
||||||
if request.form:
|
|
||||||
username = request.form.get('username')
|
|
||||||
password = request.form.get('password')
|
|
||||||
confirm_password = request.form.get('confirm_password')
|
|
||||||
remember = request.form.get('remember')
|
|
||||||
|
|
||||||
if password == confirm_password:
|
|
||||||
user_manager.create_user(username=username, password=password)
|
|
||||||
session = user_manager.create_user_session(
|
|
||||||
username=username,
|
|
||||||
password=password,
|
|
||||||
expires_at=(
|
|
||||||
utcnow() + datetime.timedelta(days=1) if not remember else None
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if session:
|
|
||||||
redirect_target = redirect(
|
|
||||||
redirect_page, 302
|
|
||||||
) # lgtm [py/url-redirection]
|
|
||||||
response = make_response(redirect_target)
|
|
||||||
response.set_cookie('session_token', session.session_token)
|
|
||||||
return response
|
|
||||||
else:
|
|
||||||
abort(400, 'Password mismatch')
|
|
||||||
|
|
||||||
return render_template('index.html', utils=HttpUtils)
|
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
|
|
@ -7,7 +7,7 @@ from typing import Optional
|
||||||
from tornado.web import RequestHandler, stream_request_body
|
from tornado.web import RequestHandler, stream_request_body
|
||||||
|
|
||||||
from platypush.backend.http.app.utils import logger
|
from platypush.backend.http.app.utils import logger
|
||||||
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
|
||||||
|
|
||||||
from ..mixins import PubSubMixin
|
from ..mixins import PubSubMixin
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class StreamingRoute(RequestHandler, PubSubMixin, ABC):
|
||||||
"""
|
"""
|
||||||
if self.auth_required:
|
if self.auth_required:
|
||||||
auth_status = get_auth_status(self.request)
|
auth_status = get_auth_status(self.request)
|
||||||
if auth_status != AuthStatus.OK:
|
if auth_status != UserAuthStatus.OK:
|
||||||
self.send_error(auth_status.value.code, error=auth_status.value.message)
|
self.send_error(auth_status.value.code, error=auth_status.value.message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -2,14 +2,14 @@ import base64
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import request, redirect, jsonify
|
from flask import request, redirect
|
||||||
from flask.wrappers import Response
|
from flask.wrappers import Response
|
||||||
|
|
||||||
from platypush.config import Config
|
from platypush.config import Config
|
||||||
from platypush.user import UserManager
|
from platypush.user import UserManager
|
||||||
|
|
||||||
from ..logger import logger
|
from ..logger import logger
|
||||||
from .status import AuthStatus
|
from .status import UserAuthStatus
|
||||||
|
|
||||||
user_manager = UserManager()
|
user_manager = UserManager()
|
||||||
|
|
||||||
|
@ -128,18 +128,18 @@ def authenticate(
|
||||||
skip_auth_methods=skip_auth_methods,
|
skip_auth_methods=skip_auth_methods,
|
||||||
)
|
)
|
||||||
|
|
||||||
if auth_status == AuthStatus.OK:
|
if auth_status == UserAuthStatus.OK:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
if json:
|
if json:
|
||||||
return jsonify(auth_status.to_dict()), auth_status.value.code
|
return auth_status.to_response()
|
||||||
|
|
||||||
if auth_status == AuthStatus.NO_USERS:
|
if auth_status == UserAuthStatus.REGISTRATION_REQUIRED:
|
||||||
return redirect(
|
return redirect(
|
||||||
f'/register?redirect={redirect_page or request.url}', 307
|
f'/register?redirect={redirect_page or request.url}', 307
|
||||||
)
|
)
|
||||||
|
|
||||||
if auth_status == AuthStatus.UNAUTHORIZED:
|
if auth_status == UserAuthStatus.INVALID_CREDENTIALS:
|
||||||
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
|
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
|
@ -154,7 +154,7 @@ def authenticate(
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-return-statements
|
# pylint: disable=too-many-return-statements
|
||||||
def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
|
def get_auth_status(req, skip_auth_methods=None) -> UserAuthStatus:
|
||||||
"""
|
"""
|
||||||
Check against the available authentication methods (except those listed in
|
Check against the available authentication methods (except those listed in
|
||||||
``skip_auth_methods``) if the user is properly authenticated.
|
``skip_auth_methods``) if the user is properly authenticated.
|
||||||
|
@ -168,29 +168,33 @@ def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
|
||||||
if n_users > 0 and 'http' not in skip_methods:
|
if n_users > 0 and 'http' not in skip_methods:
|
||||||
http_auth_ok = authenticate_user_pass(req)
|
http_auth_ok = authenticate_user_pass(req)
|
||||||
if http_auth_ok:
|
if http_auth_ok:
|
||||||
return AuthStatus.OK
|
return UserAuthStatus.OK
|
||||||
|
|
||||||
# Token-based authentication
|
# Token-based authentication
|
||||||
token_auth_ok = True
|
token_auth_ok = True
|
||||||
if 'token' not in skip_methods:
|
if 'token' not in skip_methods:
|
||||||
token_auth_ok = authenticate_token(req)
|
token_auth_ok = authenticate_token(req)
|
||||||
if token_auth_ok:
|
if token_auth_ok:
|
||||||
return AuthStatus.OK
|
return UserAuthStatus.OK
|
||||||
|
|
||||||
# Session token based authentication
|
# Session token based authentication
|
||||||
session_auth_ok = True
|
session_auth_ok = True
|
||||||
if n_users > 0 and 'session' not in skip_methods:
|
if n_users > 0 and 'session' not in skip_methods:
|
||||||
return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED
|
return (
|
||||||
|
UserAuthStatus.OK
|
||||||
|
if authenticate_session(req)
|
||||||
|
else UserAuthStatus.INVALID_CREDENTIALS
|
||||||
|
)
|
||||||
|
|
||||||
# At least a user should be created before accessing an authenticated resource
|
# At least a user should be created before accessing an authenticated resource
|
||||||
if n_users == 0 and 'session' not in skip_methods:
|
if n_users == 0 and 'session' not in skip_methods:
|
||||||
return AuthStatus.NO_USERS
|
return UserAuthStatus.REGISTRATION_REQUIRED
|
||||||
|
|
||||||
if ( # pylint: disable=too-many-boolean-expressions
|
if ( # pylint: disable=too-many-boolean-expressions
|
||||||
('http' not in skip_methods and http_auth_ok)
|
('http' not in skip_methods and http_auth_ok)
|
||||||
or ('token' not in skip_methods and token_auth_ok)
|
or ('token' not in skip_methods and token_auth_ok)
|
||||||
or ('session' not in skip_methods and session_auth_ok)
|
or ('session' not in skip_methods and session_auth_ok)
|
||||||
):
|
):
|
||||||
return AuthStatus.OK
|
return UserAuthStatus.OK
|
||||||
|
|
||||||
return AuthStatus.UNAUTHORIZED
|
return UserAuthStatus.INVALID_CREDENTIALS
|
||||||
|
|
|
@ -1,21 +1,67 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from flask import jsonify
|
||||||
|
|
||||||
StatusValue = namedtuple('StatusValue', ['code', 'message'])
|
from platypush.user import AuthenticationStatus
|
||||||
|
|
||||||
|
StatusValue = namedtuple('StatusValue', ['code', 'error', 'message'])
|
||||||
|
|
||||||
|
|
||||||
class AuthStatus(Enum):
|
class UserAuthStatus(Enum):
|
||||||
"""
|
"""
|
||||||
Models the status of the authentication.
|
Models the status of the authentication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OK = StatusValue(200, 'OK')
|
OK = StatusValue(200, AuthenticationStatus.OK, 'OK')
|
||||||
UNAUTHORIZED = StatusValue(401, 'Unauthorized')
|
INVALID_AUTH_TYPE = StatusValue(
|
||||||
NO_USERS = StatusValue(412, 'Please create a user first')
|
400, AuthenticationStatus.INVALID_AUTH_TYPE, 'Invalid authentication type'
|
||||||
|
)
|
||||||
|
INVALID_CREDENTIALS = StatusValue(
|
||||||
|
401, AuthenticationStatus.INVALID_CREDENTIALS, 'Invalid credentials'
|
||||||
|
)
|
||||||
|
INVALID_JWT_TOKEN = StatusValue(
|
||||||
|
401, AuthenticationStatus.INVALID_JWT_TOKEN, 'Invalid JWT token'
|
||||||
|
)
|
||||||
|
INVALID_OTP_CODE = StatusValue(
|
||||||
|
401, AuthenticationStatus.INVALID_OTP_CODE, 'Invalid OTP code'
|
||||||
|
)
|
||||||
|
INVALID_METHOD = StatusValue(
|
||||||
|
405, AuthenticationStatus.INVALID_METHOD, 'Invalid method'
|
||||||
|
)
|
||||||
|
MISSING_OTP_CODE = StatusValue(
|
||||||
|
401, AuthenticationStatus.MISSING_OTP_CODE, 'Missing OTP code'
|
||||||
|
)
|
||||||
|
MISSING_PASSWORD = StatusValue(
|
||||||
|
400, AuthenticationStatus.MISSING_PASSWORD, 'Missing password'
|
||||||
|
)
|
||||||
|
MISSING_USERNAME = StatusValue(
|
||||||
|
400, AuthenticationStatus.MISSING_USERNAME, 'Missing username'
|
||||||
|
)
|
||||||
|
PASSWORD_MISMATCH = StatusValue(
|
||||||
|
400, AuthenticationStatus.PASSWORD_MISMATCH, 'Password mismatch'
|
||||||
|
)
|
||||||
|
REGISTRATION_DISABLED = StatusValue(
|
||||||
|
401, AuthenticationStatus.REGISTRATION_DISABLED, 'Registrations are disabled'
|
||||||
|
)
|
||||||
|
REGISTRATION_REQUIRED = StatusValue(
|
||||||
|
412, AuthenticationStatus.REGISTRATION_REQUIRED, 'Please create a user first'
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
'code': self.value[0],
|
'code': self.value[0],
|
||||||
'message': self.value[1],
|
'error': self.value[1].name,
|
||||||
|
'message': self.value[2],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def to_response(self):
|
||||||
|
return jsonify(self.to_dict()), self.value[0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def by_status(status: AuthenticationStatus):
|
||||||
|
for auth_status in UserAuthStatus:
|
||||||
|
if auth_status.value[1] == status:
|
||||||
|
return auth_status
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -5,7 +5,7 @@ from threading import Thread
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado.websocket import WebSocketHandler
|
from tornado.websocket import WebSocketHandler
|
||||||
|
|
||||||
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
|
||||||
|
|
||||||
from ..mixins import MessageType, PubSubMixin
|
from ..mixins import MessageType, PubSubMixin
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
|
||||||
|
|
||||||
def open(self, *_, **__):
|
def open(self, *_, **__):
|
||||||
auth_status = get_auth_status(self.request)
|
auth_status = get_auth_status(self.request)
|
||||||
if auth_status != AuthStatus.OK:
|
if auth_status != UserAuthStatus.OK:
|
||||||
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
|
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,12 @@
|
||||||
<template>
|
<template>
|
||||||
|
<div id="error" v-if="initError">
|
||||||
|
<h1>Initialization error</h1>
|
||||||
|
<p>{{ initError }}</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Loading v-else-if="!initialized" />
|
||||||
|
|
||||||
|
<div id="app-container" v-else>
|
||||||
<Events ref="events" v-if="hasWebsocket" />
|
<Events ref="events" v-if="hasWebsocket" />
|
||||||
<Notifications ref="notifications" />
|
<Notifications ref="notifications" />
|
||||||
<VoiceAssistant ref="voice-assistant" v-if="hasAssistant" />
|
<VoiceAssistant ref="voice-assistant" v-if="hasAssistant" />
|
||||||
|
@ -10,11 +18,13 @@
|
||||||
|
|
||||||
<DropdownContainer />
|
<DropdownContainer />
|
||||||
<router-view />
|
<router-view />
|
||||||
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
import ConfirmDialog from "@/components/elements/ConfirmDialog";
|
import ConfirmDialog from "@/components/elements/ConfirmDialog";
|
||||||
import DropdownContainer from "@/components/elements/DropdownContainer";
|
import DropdownContainer from "@/components/elements/DropdownContainer";
|
||||||
|
import Loading from "@/components/Loading";
|
||||||
import Notifications from "@/components/Notifications";
|
import Notifications from "@/components/Notifications";
|
||||||
import Utils from "@/Utils";
|
import Utils from "@/Utils";
|
||||||
import Events from "@/Events";
|
import Events from "@/Events";
|
||||||
|
@ -29,6 +39,7 @@ export default {
|
||||||
ConfirmDialog,
|
ConfirmDialog,
|
||||||
DropdownContainer,
|
DropdownContainer,
|
||||||
Events,
|
Events,
|
||||||
|
Loading,
|
||||||
Notifications,
|
Notifications,
|
||||||
Ntfy,
|
Ntfy,
|
||||||
Pushbullet,
|
Pushbullet,
|
||||||
|
@ -41,6 +52,8 @@ export default {
|
||||||
userAuthenticated: false,
|
userAuthenticated: false,
|
||||||
connected: false,
|
connected: false,
|
||||||
pwaInstallEvent: null,
|
pwaInstallEvent: null,
|
||||||
|
initialized: false,
|
||||||
|
initError: null,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -84,8 +97,18 @@ export default {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
created() {
|
async created() {
|
||||||
this.initConfig()
|
try {
|
||||||
|
await this.initConfig()
|
||||||
|
} catch (e) {
|
||||||
|
const code = e?.response?.data?.code
|
||||||
|
if (![401, 403, 412].includes(code)) {
|
||||||
|
this.initError = e
|
||||||
|
console.error('Initialization error', e)
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
this.initialized = true
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
beforeMount() {
|
beforeMount() {
|
||||||
|
@ -125,6 +148,11 @@ html, body {
|
||||||
overflow: auto;
|
overflow: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#app-container {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
#app {
|
#app {
|
||||||
font-family: BlinkMacSystemFont,-apple-system,Avenir,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,Helvetica,Verdana,Arial,sans-serif;
|
font-family: BlinkMacSystemFont,-apple-system,Avenir,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,Helvetica,Verdana,Arial,sans-serif;
|
||||||
-webkit-font-smoothing: antialiased;
|
-webkit-font-smoothing: antialiased;
|
||||||
|
|
|
@ -8,6 +8,7 @@ $default-bg-6: #e4eae8 !default;
|
||||||
$default-bg-7: #e4e4e4 !default;
|
$default-bg-7: #e4e4e4 !default;
|
||||||
$ok-fg: #17ad17 !default;
|
$ok-fg: #17ad17 !default;
|
||||||
$error-fg: #ad1717 !default;
|
$error-fg: #ad1717 !default;
|
||||||
|
$error-bg: #ffaaa2 !default;
|
||||||
$tile-bg: linear-gradient(90deg, rgba(9,174,128,1) 0%, rgba(71,226,179,1) 120%);
|
$tile-bg: linear-gradient(90deg, rgba(9,174,128,1) 0%, rgba(71,226,179,1) 120%);
|
||||||
$tile-fg: white;
|
$tile-fg: white;
|
||||||
$tile-hover-bg: linear-gradient(90deg, rgba(41,216,159,1) 0%, rgba(9,188,138,1) 70%);
|
$tile-hover-bg: linear-gradient(90deg, rgba(41,216,159,1) 0%, rgba(9,188,138,1) 70%);
|
||||||
|
|
|
@ -42,18 +42,18 @@ export default {
|
||||||
// No users present -> redirect to the registration page
|
// No users present -> redirect to the registration page
|
||||||
if (
|
if (
|
||||||
error?.response?.data?.code === 412 &&
|
error?.response?.data?.code === 412 &&
|
||||||
window.location.href.indexOf('/register') < 0
|
window.location.pathname !== '/register'
|
||||||
) {
|
) {
|
||||||
window.location.href = '/register?redirect=' + window.location.href
|
window.location.href = '/register?redirect=' + window.location.href.split('/').slice(3).join('/')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unauthorized -> redirect to the login page
|
// Unauthorized -> redirect to the login page
|
||||||
if (
|
if (
|
||||||
error?.response?.data?.code === 401 &&
|
error?.response?.data?.code === 401 &&
|
||||||
window.location.href.indexOf('/login') < 0
|
window.location.pathname !== '/login'
|
||||||
) {
|
) {
|
||||||
window.location.href = '/login?redirect=' + window.location.href
|
window.location.href = '/login?redirect=' + window.location.href.split('/').slice(3).join('/')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
<template>
|
<template>
|
||||||
<div class="login-container">
|
<Loading v-if="!initialized" />
|
||||||
<form class="login" method="POST">
|
|
||||||
|
<div class="login-container" v-else>
|
||||||
|
<form class="login" method="POST" @submit="submitForm" v-if="!isAuthenticated">
|
||||||
<div class="header">
|
<div class="header">
|
||||||
<span class="logo">
|
<span class="logo">
|
||||||
<img src="/logo.svg" alt="logo" />
|
<img src="/logo.svg" alt="logo" />
|
||||||
|
@ -10,24 +12,30 @@
|
||||||
|
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<label>
|
<label>
|
||||||
<input type="text" name="username" placeholder="Username">
|
<input type="text" name="username" :disabled="authenticating" placeholder="Username" ref="username">
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<label>
|
<label>
|
||||||
<input type="password" name="password" placeholder="Password">
|
<input type="password" name="password" :disabled="authenticating" placeholder="Password">
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="row" v-if="_register">
|
<div class="row" v-if="register">
|
||||||
<label>
|
<label>
|
||||||
<input type="password" name="confirm_password" placeholder="Confirm password">
|
<input type="password" name="confirm_password" :disabled="authenticating" placeholder="Confirm password">
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="row buttons">
|
<div class="row buttons">
|
||||||
<input type="submit" class="btn btn-primary" :value="_register ? 'Register' : 'Login'">
|
<button type="submit"
|
||||||
|
class="btn btn-primary"
|
||||||
|
:class="{loading: authenticating}"
|
||||||
|
:disabled="authenticating">
|
||||||
|
<Loading v-if="authenticating" />
|
||||||
|
{{ register ? 'Register' : 'Login' }}
|
||||||
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="row pull-right">
|
<div class="row pull-right">
|
||||||
|
@ -36,16 +44,26 @@
|
||||||
Keep me logged in on this device
|
Keep me logged in on this device
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class="auth-error" v-if="authError">
|
||||||
|
{{ authError }}
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
|
import Loading from "@/components/Loading";
|
||||||
import Utils from "@/Utils";
|
import Utils from "@/Utils";
|
||||||
|
import axios from 'axios'
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: "Login",
|
name: "Login",
|
||||||
mixins: [Utils],
|
mixins: [Utils],
|
||||||
|
components: {
|
||||||
|
Loading,
|
||||||
|
},
|
||||||
|
|
||||||
props: {
|
props: {
|
||||||
// Set to true for a registration form, false for a login form
|
// Set to true for a registration form, false for a login form
|
||||||
register: {
|
register: {
|
||||||
|
@ -56,10 +74,86 @@ export default {
|
||||||
},
|
},
|
||||||
|
|
||||||
computed: {
|
computed: {
|
||||||
_register() {
|
redirect() {
|
||||||
return this.parseBoolean(this.register)
|
return this.$route.query.redirect?.length ? this.$route.query.redirect : '/'
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
data() {
|
||||||
|
return {
|
||||||
|
authError: null,
|
||||||
|
authenticating: false,
|
||||||
|
isAuthenticated: false,
|
||||||
|
initialized: false,
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
methods: {
|
||||||
|
async submitForm(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
const form = e.target
|
||||||
|
const data = new FormData(form)
|
||||||
|
const url = `/auth?type=${this.register ? 'register' : 'login'}`
|
||||||
|
|
||||||
|
if (this.register && data.get('password') !== data.get('confirm_password')) {
|
||||||
|
this.authError = "Passwords don't match"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
this.authError = null
|
||||||
|
|
||||||
|
try {
|
||||||
|
const authStatus = await axios.post(url, data)
|
||||||
|
const sessionToken = authStatus?.data?.session_token
|
||||||
|
if (sessionToken) {
|
||||||
|
const expiresAt = authStatus.expires_at ? Date.parse(authStatus.expires_at) : null
|
||||||
|
this.isAuthenticated = true
|
||||||
|
this.setCookie('session_token', sessionToken, {
|
||||||
|
expires: expiresAt,
|
||||||
|
})
|
||||||
|
window.location.href = authStatus.redirect || this.redirect
|
||||||
|
} else {
|
||||||
|
this.authError = "Invalid credentials"
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
this.authError = e.response.data.message || e.response.data.error
|
||||||
|
|
||||||
|
if (e.response?.status === 401) {
|
||||||
|
this.authError = this.authError || "Invalid credentials"
|
||||||
|
} else {
|
||||||
|
this.authError = this.authError || "An error occurred while processing the request"
|
||||||
|
if (e.response)
|
||||||
|
console.error(e.response.status, e.response.data)
|
||||||
|
else
|
||||||
|
console.error(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async checkAuth() {
|
||||||
|
try {
|
||||||
|
const authStatus = await axios.get('/auth')
|
||||||
|
if (authStatus.data.session_token) {
|
||||||
|
this.isAuthenticated = true
|
||||||
|
window.location.href = authStatus.redirect || this.redirect
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
this.isAuthenticated = false
|
||||||
|
} finally {
|
||||||
|
this.initialized = true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
async created() {
|
||||||
|
await this.checkAuth()
|
||||||
|
},
|
||||||
|
|
||||||
|
async mounted() {
|
||||||
|
this.$nextTick(() => {
|
||||||
|
this.$refs.username?.focus()
|
||||||
|
})
|
||||||
|
},
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
@ -116,7 +210,7 @@ form {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
input[type=submit],
|
[type=submit],
|
||||||
input[type=password] {
|
input[type=password] {
|
||||||
border-radius: 1em;
|
border-radius: 1em;
|
||||||
}
|
}
|
||||||
|
@ -133,10 +227,33 @@ form {
|
||||||
.buttons {
|
.buttons {
|
||||||
text-align: center;
|
text-align: center;
|
||||||
|
|
||||||
input[type=submit] {
|
[type=submit] {
|
||||||
|
position: relative;
|
||||||
|
width: 6em;
|
||||||
|
height: 2.5em;
|
||||||
padding: .5em .75em;
|
padding: .5em .75em;
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
|
||||||
|
&.loading {
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
cursor: not-allowed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.auth-error {
|
||||||
|
background: $error-bg;
|
||||||
|
display: flex;
|
||||||
|
margin: 1em 0 -2em 0;
|
||||||
|
padding: .5em;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
border: $notification-error-border;
|
||||||
|
border-radius: 1em;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a {
|
a {
|
||||||
|
|
|
@ -34,16 +34,17 @@ module.exports = {
|
||||||
|
|
||||||
devServer: {
|
devServer: {
|
||||||
proxy: {
|
proxy: {
|
||||||
|
'^/auth': httpProxy,
|
||||||
|
'^/camera/': httpProxy,
|
||||||
|
'^/execute': httpProxy,
|
||||||
|
'^/file': httpProxy,
|
||||||
|
'^/logo.svg': httpProxy,
|
||||||
|
'^/logout': httpProxy,
|
||||||
|
'^/media/': httpProxy,
|
||||||
|
'^/sound/': httpProxy,
|
||||||
'^/ws/events': wsProxy,
|
'^/ws/events': wsProxy,
|
||||||
'^/ws/requests': wsProxy,
|
'^/ws/requests': wsProxy,
|
||||||
'^/ws/shell': wsProxy,
|
'^/ws/shell': wsProxy,
|
||||||
'^/execute': httpProxy,
|
|
||||||
'^/file': httpProxy,
|
|
||||||
'^/auth': httpProxy,
|
|
||||||
'^/camera/': httpProxy,
|
|
||||||
'^/sound/': httpProxy,
|
|
||||||
'^/media/': httpProxy,
|
|
||||||
'^/logo.svg': httpProxy,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||||
|
|
||||||
from platypush.plugins import Plugin, action
|
from platypush.plugins import Plugin, action
|
||||||
from platypush.user import UserManager
|
from platypush.user import UserManager
|
||||||
|
@ -66,15 +66,46 @@ class UserPlugin(Plugin):
|
||||||
}
|
}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def authenticate_user(self, username, password):
|
def authenticate_user(
|
||||||
|
self,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
code: Optional[str] = None,
|
||||||
|
return_details: bool = False,
|
||||||
|
) -> Union[bool, Tuple[bool, str]]:
|
||||||
"""
|
"""
|
||||||
Authenticate a user.
|
Authenticate a user.
|
||||||
|
|
||||||
:return: True if the provided username and password are correct, False
|
:param username: Username.
|
||||||
otherwise.
|
:param password: Password.
|
||||||
"""
|
:param code: Optional 2FA code, if 2FA is enabled for the user.
|
||||||
|
:param return_error_details: If True then return the error details in
|
||||||
|
case of authentication failure.
|
||||||
|
:return: If ``return_details`` is False (default), the action returns
|
||||||
|
True if the provided credentials are valid, False otherwise.
|
||||||
|
If ``return_details`` is True then the action returns a tuple
|
||||||
|
(authenticated, error_details) where ``authenticated`` is True if
|
||||||
|
the provided credentials are valid, False otherwise, and
|
||||||
|
``error_details`` is a string containing the error details in case
|
||||||
|
of authentication failure. Supported error details are:
|
||||||
|
|
||||||
return bool(self.user_manager.authenticate_user(username, password))
|
- ``invalid_credentials``: Invalid username or password.
|
||||||
|
- ``invalid_otp_code``: Invalid 2FA code.
|
||||||
|
- ``missing_otp_code``: Username/password are correct, but a 2FA
|
||||||
|
code is required for the user.
|
||||||
|
|
||||||
|
"""
|
||||||
|
response = self.user_manager.authenticate_user(
|
||||||
|
username, password, code=code, return_error=return_details
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_details:
|
||||||
|
assert (
|
||||||
|
isinstance(response, tuple) and len(response) == 2
|
||||||
|
), 'Invalid response from authenticate_user'
|
||||||
|
return response[0], response[1].value
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def update_password(self, username, old_password, new_password):
|
def update_password(self, username, old_password, new_password):
|
||||||
|
@ -111,7 +142,7 @@ class UserPlugin(Plugin):
|
||||||
return None, "No such user: {}".format(username)
|
return None, "No such user: {}".format(username)
|
||||||
|
|
||||||
@action
|
@action
|
||||||
def create_session(self, username, password, expires_at=None):
|
def create_session(self, username, password, code=None, expires_at=None):
|
||||||
"""
|
"""
|
||||||
Create a user session.
|
Create a user session.
|
||||||
|
|
||||||
|
@ -130,7 +161,7 @@ class UserPlugin(Plugin):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
session = self.user_manager.create_user_session(
|
session = self.user_manager.create_user_session(
|
||||||
username=username, password=password, expires_at=expires_at
|
username=username, password=password, code=code, expires_at=expires_at
|
||||||
)
|
)
|
||||||
|
|
||||||
if not session:
|
if not session:
|
||||||
|
@ -140,9 +171,9 @@ class UserPlugin(Plugin):
|
||||||
'session_token': session.session_token,
|
'session_token': session.session_token,
|
||||||
'user_id': session.user_id,
|
'user_id': session.user_id,
|
||||||
'created_at': session.created_at.isoformat(),
|
'created_at': session.created_at.isoformat(),
|
||||||
'expires_at': session.expires_at.isoformat()
|
'expires_at': (
|
||||||
if session.expires_at
|
session.expires_at.isoformat() if session.expires_at else None # type: ignore
|
||||||
else None,
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import base64
|
import base64
|
||||||
import datetime
|
import datetime
|
||||||
|
import enum
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Dict
|
from typing import List, Optional, Dict, Tuple, Union
|
||||||
|
|
||||||
import rsa
|
import rsa
|
||||||
|
|
||||||
|
@ -13,12 +14,32 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
from sqlalchemy.orm import make_transient
|
from sqlalchemy.orm import make_transient
|
||||||
|
|
||||||
from platypush.common.db import Base
|
from platypush.common.db import Base
|
||||||
|
from platypush.config import Config
|
||||||
from platypush.context import get_plugin
|
from platypush.context import get_plugin
|
||||||
from platypush.exceptions.user import (
|
from platypush.exceptions.user import (
|
||||||
InvalidJWTTokenException,
|
InvalidJWTTokenException,
|
||||||
InvalidCredentialsException,
|
InvalidCredentialsException,
|
||||||
)
|
)
|
||||||
from platypush.utils import get_or_generate_jwt_rsa_key_pair, utcnow
|
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationStatus(enum.Enum):
|
||||||
|
"""
|
||||||
|
Enum for authentication errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
OK = ''
|
||||||
|
INVALID_AUTH_TYPE = 'invalid_auth_type'
|
||||||
|
INVALID_CREDENTIALS = 'invalid_credentials'
|
||||||
|
INVALID_METHOD = 'invalid_method'
|
||||||
|
INVALID_JWT_TOKEN = 'invalid_jwt_token'
|
||||||
|
INVALID_OTP_CODE = 'invalid_otp_code'
|
||||||
|
MISSING_OTP_CODE = 'missing_otp_code'
|
||||||
|
MISSING_PASSWORD = 'missing_password'
|
||||||
|
MISSING_USERNAME = 'missing_username'
|
||||||
|
PASSWORD_MISMATCH = 'password_mismatch'
|
||||||
|
REGISTRATION_DISABLED = 'registration_disabled'
|
||||||
|
REGISTRATION_REQUIRED = 'registration_required'
|
||||||
|
|
||||||
|
|
||||||
class UserManager:
|
class UserManager:
|
||||||
|
@ -26,6 +47,14 @@ class UserManager:
|
||||||
Main class for managing platform users
|
Main class for managing platform users
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_otp_workdir = os.path.join(Config.get_workdir(), 'otp')
|
||||||
|
_otp_keyfile = os.path.join(_otp_workdir, 'key')
|
||||||
|
_otp_keyfile_pub = f'{_otp_keyfile}.pub'
|
||||||
|
|
||||||
|
_jwt_workdir = os.path.join(Config.get_workdir(), 'jwt')
|
||||||
|
_jwt_keyfile = os.path.join(_jwt_workdir, 'id_rsa')
|
||||||
|
_jwt_keyfile_pub = f'{_jwt_keyfile}.pub'
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
db_plugin = get_plugin('db')
|
db_plugin = get_plugin('db')
|
||||||
assert db_plugin, 'Database plugin not configured'
|
assert db_plugin, 'Database plugin not configured'
|
||||||
|
@ -41,6 +70,44 @@ class UserManager:
|
||||||
def _get_session(self, *args, **kwargs):
|
def _get_session(self, *args, **kwargs):
|
||||||
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
|
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_jwt_rsa_key_pair(cls):
|
||||||
|
"""
|
||||||
|
Get or generate the JWT RSA key pair.
|
||||||
|
"""
|
||||||
|
return get_or_generate_stored_rsa_key_pair(cls._jwt_keyfile, size=2048)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_or_generate_otp_rsa_key_pair(cls):
|
||||||
|
"""
|
||||||
|
Get or generate the OTP RSA key pair.
|
||||||
|
"""
|
||||||
|
return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=4096)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encrypt(data: Union[str, bytes, dict, list, tuple], key: rsa.PublicKey) -> str:
|
||||||
|
"""
|
||||||
|
Encrypt the data using the given RSA public key.
|
||||||
|
"""
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
data = list(data)
|
||||||
|
if isinstance(data, (dict, list)):
|
||||||
|
data = json.dumps(data, sort_keys=True, indent=None)
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode('ascii')
|
||||||
|
|
||||||
|
return base64.b64encode(rsa.encrypt(data, key)).decode()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decrypt(data: Union[str, bytes], key: rsa.PrivateKey) -> str:
|
||||||
|
"""
|
||||||
|
Decrypt the data using the given RSA private key.
|
||||||
|
"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode('ascii')
|
||||||
|
|
||||||
|
return rsa.decrypt(base64.b64decode(data), key).decode()
|
||||||
|
|
||||||
def get_user(self, username):
|
def get_user(self, username):
|
||||||
with self._get_session() as session:
|
with self._get_session() as session:
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
|
@ -88,9 +155,9 @@ class UserManager:
|
||||||
|
|
||||||
return self._mask_password(user)
|
return self._mask_password(user)
|
||||||
|
|
||||||
def update_password(self, username, old_password, new_password):
|
def update_password(self, username, old_password, new_password, code=None):
|
||||||
with self._get_session(locked=True) as session:
|
with self._get_session(locked=True) as session:
|
||||||
if not self._authenticate_user(session, username, old_password):
|
if not self._authenticate_user(session, username, old_password, code=code):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
|
@ -103,12 +170,22 @@ class UserManager:
|
||||||
session.commit()
|
session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def authenticate_user(self, username, password):
|
def authenticate_user(self, username, password, code=None, return_error=False):
|
||||||
with self._get_session() as session:
|
with self._get_session() as session:
|
||||||
return self._authenticate_user(session, username, password)
|
return self._authenticate_user(
|
||||||
|
session, username, password, code=code, return_error=return_error
|
||||||
|
)
|
||||||
|
|
||||||
def authenticate_user_session(self, session_token):
|
def authenticate_user_session(self, session_token, with_error=False):
|
||||||
with self._get_session() as session:
|
with self._get_session() as session:
|
||||||
|
users_count = session.query(User).count()
|
||||||
|
if not users_count:
|
||||||
|
return (
|
||||||
|
(None, None, AuthenticationStatus.REGISTRATION_REQUIRED)
|
||||||
|
if with_error
|
||||||
|
else (None, None)
|
||||||
|
)
|
||||||
|
|
||||||
user_session = (
|
user_session = (
|
||||||
session.query(UserSession)
|
session.query(UserSession)
|
||||||
.filter_by(session_token=session_token)
|
.filter_by(session_token=session_token)
|
||||||
|
@ -122,10 +199,21 @@ class UserManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
if not user_session or (expires_at and expires_at < utcnow()):
|
if not user_session or (expires_at and expires_at < utcnow()):
|
||||||
return None, None
|
return (
|
||||||
|
(None, None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||||
|
if with_error
|
||||||
|
else (None, None)
|
||||||
|
)
|
||||||
|
|
||||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||||
return self._mask_password(user), user_session
|
return (
|
||||||
|
(self._mask_password(user), user_session, AuthenticationStatus.OK)
|
||||||
|
if with_error
|
||||||
|
else (
|
||||||
|
self._mask_password(user),
|
||||||
|
user_session,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def delete_user(self, username):
|
def delete_user(self, username):
|
||||||
with self._get_session(locked=True) as session:
|
with self._get_session(locked=True) as session:
|
||||||
|
@ -158,11 +246,25 @@ class UserManager:
|
||||||
session.commit()
|
session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def create_user_session(self, username, password, expires_at=None):
|
def create_user_session(
|
||||||
|
self,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
code=None,
|
||||||
|
expires_at=None,
|
||||||
|
error_on_invalid_credentials=False,
|
||||||
|
):
|
||||||
with self._get_session(locked=True) as session:
|
with self._get_session(locked=True) as session:
|
||||||
user = self._authenticate_user(session, username, password)
|
user, status = self._authenticate_user( # type: ignore
|
||||||
|
session,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
code=code,
|
||||||
|
return_error=error_on_invalid_credentials,
|
||||||
|
)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None if not error_on_invalid_credentials else (None, status)
|
||||||
|
|
||||||
if expires_at:
|
if expires_at:
|
||||||
if isinstance(expires_at, (int, float)):
|
if isinstance(expires_at, (int, float)):
|
||||||
|
@ -180,7 +282,53 @@ class UserManager:
|
||||||
|
|
||||||
session.add(user_session)
|
session.add(user_session)
|
||||||
session.commit()
|
session.commit()
|
||||||
return user_session
|
return user_session, (
|
||||||
|
AuthenticationStatus.OK if not error_on_invalid_credentials else status
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_otp_secret(
|
||||||
|
self, username: str, expires_at: Optional[datetime.datetime] = None
|
||||||
|
):
|
||||||
|
pubkey, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
|
||||||
|
# Generate a new OTP secret and encrypt it with the OTP RSA key pair
|
||||||
|
otp_secret = "".join(
|
||||||
|
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(32)
|
||||||
|
)
|
||||||
|
|
||||||
|
encrypted_secret = self._encrypt(otp_secret, pubkey)
|
||||||
|
|
||||||
|
with self._get_session(locked=True) as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
assert user, f'No such user: {username}'
|
||||||
|
|
||||||
|
# Create a new OTP secret
|
||||||
|
user_otp = UserOtp(
|
||||||
|
user_id=user.user_id,
|
||||||
|
otp_secret=encrypted_secret,
|
||||||
|
created_at=utcnow(),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove any existing OTP secret and replace it with the new one
|
||||||
|
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
|
||||||
|
session.add(user_otp)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return user_otp
|
||||||
|
|
||||||
|
def get_otp_secret(self, username: str) -> Optional[str]:
|
||||||
|
with self._get_session() as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_otp = session.query(UserOtp).filter_by(user_id=user.user_id).first()
|
||||||
|
if not user_otp:
|
||||||
|
return None
|
||||||
|
|
||||||
|
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
return self._decrypt(user_otp.otp_secret, priv_key)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_user(session, username):
|
def _get_user(session, username):
|
||||||
|
@ -268,20 +416,17 @@ class UserManager:
|
||||||
if not user:
|
if not user:
|
||||||
raise InvalidCredentialsException()
|
raise InvalidCredentialsException()
|
||||||
|
|
||||||
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
|
pub_key, _ = self._get_jwt_rsa_key_pair()
|
||||||
payload = json.dumps(
|
return self._encrypt(
|
||||||
{
|
{
|
||||||
'username': username,
|
'username': username,
|
||||||
'password': password,
|
'password': password,
|
||||||
'created_at': datetime.datetime.now().timestamp(),
|
'created_at': datetime.datetime.now().timestamp(),
|
||||||
'expires_at': expires_at.timestamp() if expires_at else None,
|
'expires_at': expires_at.timestamp() if expires_at else None,
|
||||||
},
|
},
|
||||||
sort_keys=True,
|
pub_key,
|
||||||
indent=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
|
|
||||||
|
|
||||||
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Validate a JWT token.
|
Validate a JWT token.
|
||||||
|
@ -299,14 +444,10 @@ class UserManager:
|
||||||
|
|
||||||
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
|
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
|
||||||
"""
|
"""
|
||||||
_, priv_key = get_or_generate_jwt_rsa_key_pair()
|
_, priv_key = self._get_jwt_rsa_key_pair()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(
|
payload = json.loads(self._decrypt(token, priv_key))
|
||||||
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
|
|
||||||
'ascii'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except (TypeError, ValueError) as e:
|
except (TypeError, ValueError) as e:
|
||||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
|
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
|
||||||
|
|
||||||
|
@ -323,23 +464,160 @@ class UserManager:
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _authenticate_user(self, session, username, password):
|
def _authenticate_user(
|
||||||
|
self,
|
||||||
|
session,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
code: Optional[str] = None,
|
||||||
|
return_error: bool = False,
|
||||||
|
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
|
||||||
"""
|
"""
|
||||||
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
|
:return: :class:`platypush.user.User` instance if the user exists and
|
||||||
|
the password is valid, ``None`` otherwise.
|
||||||
"""
|
"""
|
||||||
user = self._get_user(session, username)
|
user = self._get_user(session, username)
|
||||||
if not user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
# The user does not exist
|
||||||
|
if not user:
|
||||||
|
return (
|
||||||
|
None
|
||||||
|
if not return_error
|
||||||
|
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||||
|
)
|
||||||
|
|
||||||
|
# The password is not correct
|
||||||
if not self._check_password(
|
if not self._check_password(
|
||||||
password,
|
password,
|
||||||
user.password,
|
user.password,
|
||||||
bytes.fromhex(user.password_salt) if user.password_salt else None,
|
bytes.fromhex(user.password_salt) if user.password_salt else None,
|
||||||
user.hmac_iterations,
|
user.hmac_iterations,
|
||||||
):
|
):
|
||||||
return None
|
return (
|
||||||
|
None
|
||||||
|
if not return_error
|
||||||
|
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||||
|
)
|
||||||
|
|
||||||
return user
|
otp_secret = self.get_otp_secret(username)
|
||||||
|
|
||||||
|
# The user doesn't have 2FA enabled and the password is correct:
|
||||||
|
# authentication successful
|
||||||
|
if not otp_secret:
|
||||||
|
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||||
|
|
||||||
|
# The user has 2FA enabled but the code is missing
|
||||||
|
if not code:
|
||||||
|
return (
|
||||||
|
None
|
||||||
|
if not return_error
|
||||||
|
else (None, AuthenticationStatus.MISSING_OTP_CODE)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.validate_otp_code(username, code):
|
||||||
|
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||||
|
|
||||||
|
if not self.validate_backup_code(username, code):
|
||||||
|
return (
|
||||||
|
None
|
||||||
|
if not return_error
|
||||||
|
else (None, AuthenticationStatus.INVALID_OTP_CODE)
|
||||||
|
)
|
||||||
|
|
||||||
|
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||||
|
|
||||||
|
def refresh_user_backup_codes(self, username: str):
|
||||||
|
"""
|
||||||
|
Refresh the backup codes for a user with 2FA enabled.
|
||||||
|
"""
|
||||||
|
with self._get_session(locked=True) as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
|
||||||
|
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
|
||||||
|
for _ in range(10):
|
||||||
|
backup_code = "".join(
|
||||||
|
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16)
|
||||||
|
)
|
||||||
|
|
||||||
|
user_backup_code = UserBackupCode(
|
||||||
|
user_id=user.user_id,
|
||||||
|
code=self._encrypt(backup_code, pub_key),
|
||||||
|
created_at=utcnow(),
|
||||||
|
expires_at=utcnow() + datetime.timedelta(days=30),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(user_backup_code)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_user_backup_codes(self, username: str) -> List['UserBackupCode']:
|
||||||
|
with self._get_session() as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return []
|
||||||
|
|
||||||
|
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
codes = session.query(UserBackupCode).filter_by(user_id=user.user_id).all()
|
||||||
|
|
||||||
|
for code in codes:
|
||||||
|
code.code = self._decrypt(code.code, priv_key)
|
||||||
|
|
||||||
|
return codes
|
||||||
|
|
||||||
|
def validate_backup_code(self, username: str, code: str) -> bool:
|
||||||
|
with self._get_session() as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
user_backup_code = (
|
||||||
|
session.query(UserBackupCode)
|
||||||
|
.filter_by(user_id=user.user_id, code=self._encrypt(code, pub_key))
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_backup_code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.delete(user_backup_code)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_otp_code(self, username: str, code: str) -> bool:
|
||||||
|
otp = get_plugin('otp')
|
||||||
|
assert otp
|
||||||
|
|
||||||
|
with self._get_session() as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
otp_secret = self.get_otp_secret(username)
|
||||||
|
if not otp_secret:
|
||||||
|
return False
|
||||||
|
|
||||||
|
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||||
|
otp_secret = self._decrypt(otp_secret, priv_key)
|
||||||
|
|
||||||
|
return otp.verify_time_otp(otp=code, secret=otp_secret)
|
||||||
|
|
||||||
|
def disable_mfa(self, username: str):
|
||||||
|
with self._get_session(locked=True) as session:
|
||||||
|
user = self._get_user(session, username)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
|
||||||
|
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
|
||||||
|
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
|
@ -370,4 +648,31 @@ class UserSession(Base):
|
||||||
expires_at = Column(DateTime)
|
expires_at = Column(DateTime)
|
||||||
|
|
||||||
|
|
||||||
|
class UserOtp(Base):
|
||||||
|
"""
|
||||||
|
Models the UserOtp table, which contains the OTP secrets for each user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = 'user_otp'
|
||||||
|
|
||||||
|
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||||
|
otp_secret = Column(String, nullable=False, unique=True)
|
||||||
|
created_at = Column(DateTime)
|
||||||
|
expires_at = Column(DateTime)
|
||||||
|
|
||||||
|
|
||||||
|
class UserBackupCode(Base):
|
||||||
|
"""
|
||||||
|
Models the UserBackupCode table, which contains the backup codes for each
|
||||||
|
user with 2FA enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = 'user_backup_code'
|
||||||
|
|
||||||
|
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||||
|
code = Column(String, nullable=False, unique=True)
|
||||||
|
created_at = Column(DateTime)
|
||||||
|
expires_at = Column(DateTime)
|
||||||
|
|
||||||
|
|
||||||
# vim:sw=4:ts=4:et:
|
# vim:sw=4:ts=4:et:
|
||||||
|
|
|
@ -532,7 +532,12 @@ def generate_rsa_key_pair(
|
||||||
private_key_str = priv_key.save_pkcs1('PEM').decode()
|
private_key_str = priv_key.save_pkcs1('PEM').decode()
|
||||||
|
|
||||||
if key_file:
|
if key_file:
|
||||||
|
pathlib.Path(os.path.dirname(os.path.expanduser(key_file))).mkdir(
|
||||||
|
parents=True, exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
logger.info('Saving private key to %s', key_file)
|
logger.info('Saving private key to %s', key_file)
|
||||||
|
|
||||||
with open(os.path.expanduser(key_file), 'w') as f1, open(
|
with open(os.path.expanduser(key_file), 'w') as f1, open(
|
||||||
os.path.expanduser(key_file) + '.pub', 'w'
|
os.path.expanduser(key_file) + '.pub', 'w'
|
||||||
) as f2:
|
) as f2:
|
||||||
|
@ -543,14 +548,20 @@ def generate_rsa_key_pair(
|
||||||
return pub_key, priv_key
|
return pub_key, priv_key
|
||||||
|
|
||||||
|
|
||||||
def get_or_generate_jwt_rsa_key_pair():
|
def get_or_generate_stored_rsa_key_pair(
|
||||||
|
keyfile: str, size: int = 2048
|
||||||
|
) -> Tuple[PublicKey, PrivateKey]:
|
||||||
"""
|
"""
|
||||||
Get or generate a JWT RSA key pair.
|
Get or generate an RSA key pair and store it in the given key file.
|
||||||
"""
|
|
||||||
from platypush.config import Config
|
|
||||||
|
|
||||||
key_dir = os.path.join(Config.get_workdir(), 'jwt')
|
The private key will be stored in the given file, while the public key will
|
||||||
priv_key_file = os.path.join(key_dir, 'id_rsa')
|
be stored in ``<keyfile>.pub``.
|
||||||
|
|
||||||
|
:param keyfile: Path to the key file.
|
||||||
|
:param size: Key size in bits (default: 2048).
|
||||||
|
"""
|
||||||
|
keydir = os.path.dirname(os.path.expanduser(keyfile))
|
||||||
|
priv_key_file = os.path.join(keydir, os.path.basename(keyfile))
|
||||||
pub_key_file = priv_key_file + '.pub'
|
pub_key_file = priv_key_file + '.pub'
|
||||||
|
|
||||||
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
||||||
|
@ -560,8 +571,15 @@ def get_or_generate_jwt_rsa_key_pair():
|
||||||
PrivateKey.load_pkcs1(f2.read().encode()),
|
PrivateKey.load_pkcs1(f2.read().encode()),
|
||||||
)
|
)
|
||||||
|
|
||||||
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
pub_key, priv_key = generate_rsa_key_pair(priv_key_file, size=size)
|
||||||
return generate_rsa_key_pair(priv_key_file, size=2048)
|
pathlib.Path(keydir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||||
|
|
||||||
|
with open(pub_key_file, 'w') as f1, open(priv_key_file, 'w') as f2:
|
||||||
|
f1.write(pub_key.save_pkcs1('PEM').decode())
|
||||||
|
f2.write(priv_key.save_pkcs1('PEM').decode())
|
||||||
|
os.chmod(priv_key_file, 0o600)
|
||||||
|
|
||||||
|
return pub_key, priv_key
|
||||||
|
|
||||||
|
|
||||||
def get_enabled_plugins() -> dict:
|
def get_enabled_plugins() -> dict:
|
||||||
|
|
Loading…
Reference in a new issue