forked from platypush/platypush
[core] Refactored Web login/registration layer.
Instead of having a single Flask-provided endpoint, the UI should initialize its own Vue component and manage the authentication asynchronously over API. This is especially a requirement for the implementation of 2FA. The following routes have also been merged/refactored: - `POST /register` -> `POST /auth?type=register` - `POST /login` -> `POST /auth?type=login` - `POST /auth` -> `POST /auth?type=jwt`
This commit is contained in:
parent
8904e40f9f
commit
ee27b2c4c6
14 changed files with 839 additions and 262 deletions
|
@ -4,8 +4,10 @@ import logging
|
|||
|
||||
from flask import Blueprint, request, abort, jsonify
|
||||
|
||||
from platypush.backend.http.app.utils.auth import UserAuthStatus
|
||||
from platypush.exceptions.user import UserException
|
||||
from platypush.user import UserManager
|
||||
from platypush.utils import utcnow
|
||||
|
||||
auth = Blueprint('auth', __name__)
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -16,39 +18,24 @@ __routes__ = [
|
|||
]
|
||||
|
||||
|
||||
@auth.route('/auth', methods=['POST'])
|
||||
def auth_endpoint():
|
||||
"""
|
||||
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
|
||||
structure:
|
||||
def _dump_session(session, redirect_page='/'):
|
||||
return jsonify(
|
||||
{
|
||||
'status': 'ok',
|
||||
'user_id': session.user_id,
|
||||
'session_token': session.session_token,
|
||||
'expires_at': session.expires_at,
|
||||
'redirect': redirect_page,
|
||||
}
|
||||
)
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"username": "USERNAME",
|
||||
"password": "PASSWORD",
|
||||
"expiry_days": "The generated token should be valid for these many days"
|
||||
}
|
||||
|
||||
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
|
||||
|
||||
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
|
||||
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
|
||||
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
|
||||
|
||||
:return: Return structure:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"token": "<generated token here>"
|
||||
}
|
||||
"""
|
||||
def _jwt_auth():
|
||||
try:
|
||||
payload = json.loads(request.get_data(as_text=True))
|
||||
username, password = payload['username'], payload['password']
|
||||
except Exception as e:
|
||||
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
|
||||
except Exception:
|
||||
log.warning('Invalid payload passed to the auth endpoint')
|
||||
abort(400)
|
||||
|
||||
expiry_days = payload.get('expiry_days')
|
||||
|
@ -59,8 +46,174 @@ def auth_endpoint():
|
|||
user_manager = UserManager()
|
||||
|
||||
try:
|
||||
return jsonify({
|
||||
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
'token': user_manager.generate_jwt_token(
|
||||
username=username, password=password, expires_at=expires_at
|
||||
),
|
||||
}
|
||||
)
|
||||
except UserException as e:
|
||||
abort(401, str(e))
|
||||
|
||||
|
||||
def _session_auth():
|
||||
user_manager = UserManager()
|
||||
session_token = request.cookies.get('session_token')
|
||||
redirect_page = request.args.get('redirect') or '/'
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||
if user and session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
if request.form:
|
||||
username = request.form.get('username')
|
||||
password = request.form.get('password')
|
||||
code = request.form.get('code')
|
||||
remember = request.form.get('remember')
|
||||
expires = utcnow() + datetime.timedelta(days=365) if remember else None
|
||||
session, status = user_manager.create_user_session( # type: ignore
|
||||
username=username,
|
||||
password=password,
|
||||
code=code,
|
||||
expires_at=expires,
|
||||
error_on_invalid_credentials=True,
|
||||
)
|
||||
|
||||
if session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
if status:
|
||||
return status.to_response() # type: ignore
|
||||
|
||||
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||
|
||||
|
||||
def _register_route():
|
||||
"""Registration endpoint"""
|
||||
user_manager = UserManager()
|
||||
session_token = request.cookies.get('session_token')
|
||||
redirect_page = request.args.get('redirect') or '/'
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||
if user and session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
if user_manager.get_user_count() > 0:
|
||||
return UserAuthStatus.REGISTRATION_DISABLED.to_response()
|
||||
|
||||
if not request.form:
|
||||
return UserAuthStatus.MISSING_USERNAME.to_response()
|
||||
|
||||
username = request.form.get('username')
|
||||
password = request.form.get('password')
|
||||
confirm_password = request.form.get('confirm_password')
|
||||
remember = request.form.get('remember')
|
||||
|
||||
if not username:
|
||||
return UserAuthStatus.MISSING_USERNAME.to_response()
|
||||
if not password:
|
||||
return UserAuthStatus.MISSING_PASSWORD.to_response()
|
||||
if password != confirm_password:
|
||||
return UserAuthStatus.PASSWORD_MISMATCH.to_response()
|
||||
|
||||
user_manager.create_user(username=username, password=password)
|
||||
session, status = user_manager.create_user_session( # type: ignore
|
||||
username=username,
|
||||
password=password,
|
||||
expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None),
|
||||
error_on_invalid_credentials=True,
|
||||
)
|
||||
|
||||
if session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
if status:
|
||||
return status.to_response() # type: ignore
|
||||
|
||||
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||
|
||||
|
||||
def _auth_get():
|
||||
"""
|
||||
Get the current authentication status of the user session.
|
||||
"""
|
||||
user_manager = UserManager()
|
||||
session_token = request.cookies.get('session_token')
|
||||
redirect_page = request.args.get('redirect') or '/'
|
||||
user, session, status = user_manager.authenticate_user_session( # type: ignore
|
||||
session_token, with_error=True
|
||||
)
|
||||
|
||||
if user and session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
if status:
|
||||
return UserAuthStatus.by_status(status).to_response() # type: ignore
|
||||
|
||||
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||
|
||||
|
||||
def _auth_post():
|
||||
"""
|
||||
Authenticate the user session.
|
||||
"""
|
||||
auth_type = request.args.get('type') or 'jwt'
|
||||
|
||||
if auth_type == 'jwt':
|
||||
return _jwt_auth()
|
||||
|
||||
if auth_type == 'register':
|
||||
return _register_route()
|
||||
|
||||
if auth_type == 'login':
|
||||
return _session_auth()
|
||||
|
||||
return UserAuthStatus.INVALID_AUTH_TYPE.to_response()
|
||||
|
||||
|
||||
@auth.route('/auth', methods=['GET', 'POST'])
|
||||
def auth_endpoint():
|
||||
"""
|
||||
Authentication endpoint. It validates the user credentials provided over a
|
||||
JSON payload with the following structure:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"username": "USERNAME",
|
||||
"password": "PASSWORD",
|
||||
"code": "2FA_CODE",
|
||||
"expiry_days": "The generated token should be valid for these many days"
|
||||
}
|
||||
|
||||
``expiry_days`` is optional, and if omitted or set to zero the token will
|
||||
be valid indefinitely.
|
||||
|
||||
Upon successful validation, a new JWT token will be generated using the
|
||||
service's self-generated RSA key-pair and it will be returned to the user.
|
||||
The token can then be used to authenticate API calls to ``/execute`` by
|
||||
setting the ``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
|
||||
|
||||
:return: Return structure:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"token": "<generated token here>"
|
||||
}
|
||||
"""
|
||||
if request.method == 'GET':
|
||||
return _auth_get()
|
||||
|
||||
if request.method == 'POST':
|
||||
return _auth_post()
|
||||
|
||||
return UserAuthStatus.INVALID_METHOD.to_response()
|
||||
|
||||
|
||||
# from flask import Blueprint, request, redirect, render_template, make_response, abort
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -1,56 +0,0 @@
|
|||
import datetime
|
||||
import re
|
||||
|
||||
from flask import Blueprint, request, redirect, render_template, make_response
|
||||
|
||||
from platypush.backend.http.app import template_folder
|
||||
from platypush.backend.http.utils import HttpUtils
|
||||
from platypush.user import UserManager
|
||||
from platypush.utils import utcnow
|
||||
|
||||
login = Blueprint('login', __name__, template_folder=template_folder)
|
||||
|
||||
# Declare routes list
|
||||
__routes__ = [
|
||||
login,
|
||||
]
|
||||
|
||||
|
||||
@login.route('/login', methods=['GET', 'POST'])
|
||||
def login():
|
||||
"""Login page"""
|
||||
user_manager = UserManager()
|
||||
session_token = request.cookies.get('session_token')
|
||||
|
||||
redirect_page = request.args.get('redirect')
|
||||
if not redirect_page:
|
||||
redirect_page = request.headers.get('Referer', '/')
|
||||
if re.search('(^https?://[^/]+)?/login[^?#]?', redirect_page):
|
||||
# Prevent redirect loop
|
||||
redirect_page = '/'
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token)
|
||||
if user:
|
||||
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
||||
|
||||
if request.form:
|
||||
username = request.form.get('username')
|
||||
password = request.form.get('password')
|
||||
remember = request.form.get('remember')
|
||||
expires = utcnow() + datetime.timedelta(days=365) if remember else None
|
||||
|
||||
session = user_manager.create_user_session(
|
||||
username=username, password=password, expires_at=expires
|
||||
)
|
||||
|
||||
if session:
|
||||
redirect_target = redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
||||
response = make_response(redirect_target)
|
||||
response.set_cookie('session_token', session.session_token, expires=expires)
|
||||
return response
|
||||
|
||||
return render_template('index.html', utils=HttpUtils)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
|
@ -1,71 +0,0 @@
|
|||
import datetime
|
||||
import re
|
||||
|
||||
from flask import Blueprint, request, redirect, render_template, make_response, abort
|
||||
|
||||
from platypush.backend.http.app import template_folder
|
||||
from platypush.backend.http.utils import HttpUtils
|
||||
from platypush.user import UserManager
|
||||
from platypush.utils import utcnow
|
||||
|
||||
register = Blueprint('register', __name__, template_folder=template_folder)
|
||||
|
||||
# Declare routes list
|
||||
__routes__ = [
|
||||
register,
|
||||
]
|
||||
|
||||
|
||||
@register.route('/register', methods=['GET', 'POST'])
|
||||
def register():
|
||||
"""Registration page"""
|
||||
user_manager = UserManager()
|
||||
redirect_page = request.args.get('redirect')
|
||||
if not redirect_page:
|
||||
redirect_page = request.headers.get('Referer', '/')
|
||||
if re.search('(^https?://[^/]+)?/register[^?#]?', redirect_page):
|
||||
# Prevent redirect loop
|
||||
redirect_page = '/'
|
||||
|
||||
session_token = request.cookies.get('session_token')
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token)
|
||||
if user:
|
||||
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
|
||||
|
||||
if user_manager.get_user_count() > 0:
|
||||
return redirect(
|
||||
'/login?redirect=' + redirect_page, 302
|
||||
) # lgtm [py/url-redirection]
|
||||
|
||||
if request.form:
|
||||
username = request.form.get('username')
|
||||
password = request.form.get('password')
|
||||
confirm_password = request.form.get('confirm_password')
|
||||
remember = request.form.get('remember')
|
||||
|
||||
if password == confirm_password:
|
||||
user_manager.create_user(username=username, password=password)
|
||||
session = user_manager.create_user_session(
|
||||
username=username,
|
||||
password=password,
|
||||
expires_at=(
|
||||
utcnow() + datetime.timedelta(days=1) if not remember else None
|
||||
),
|
||||
)
|
||||
|
||||
if session:
|
||||
redirect_target = redirect(
|
||||
redirect_page, 302
|
||||
) # lgtm [py/url-redirection]
|
||||
response = make_response(redirect_target)
|
||||
response.set_cookie('session_token', session.session_token)
|
||||
return response
|
||||
else:
|
||||
abort(400, 'Password mismatch')
|
||||
|
||||
return render_template('index.html', utils=HttpUtils)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
|
@ -7,7 +7,7 @@ from typing import Optional
|
|||
from tornado.web import RequestHandler, stream_request_body
|
||||
|
||||
from platypush.backend.http.app.utils import logger
|
||||
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
||||
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
|
||||
|
||||
from ..mixins import PubSubMixin
|
||||
|
||||
|
@ -29,7 +29,7 @@ class StreamingRoute(RequestHandler, PubSubMixin, ABC):
|
|||
"""
|
||||
if self.auth_required:
|
||||
auth_status = get_auth_status(self.request)
|
||||
if auth_status != AuthStatus.OK:
|
||||
if auth_status != UserAuthStatus.OK:
|
||||
self.send_error(auth_status.value.code, error=auth_status.value.message)
|
||||
return
|
||||
|
||||
|
|
|
@ -2,14 +2,14 @@ import base64
|
|||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request, redirect, jsonify
|
||||
from flask import request, redirect
|
||||
from flask.wrappers import Response
|
||||
|
||||
from platypush.config import Config
|
||||
from platypush.user import UserManager
|
||||
|
||||
from ..logger import logger
|
||||
from .status import AuthStatus
|
||||
from .status import UserAuthStatus
|
||||
|
||||
user_manager = UserManager()
|
||||
|
||||
|
@ -128,18 +128,18 @@ def authenticate(
|
|||
skip_auth_methods=skip_auth_methods,
|
||||
)
|
||||
|
||||
if auth_status == AuthStatus.OK:
|
||||
if auth_status == UserAuthStatus.OK:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if json:
|
||||
return jsonify(auth_status.to_dict()), auth_status.value.code
|
||||
return auth_status.to_response()
|
||||
|
||||
if auth_status == AuthStatus.NO_USERS:
|
||||
if auth_status == UserAuthStatus.REGISTRATION_REQUIRED:
|
||||
return redirect(
|
||||
f'/register?redirect={redirect_page or request.url}', 307
|
||||
)
|
||||
|
||||
if auth_status == AuthStatus.UNAUTHORIZED:
|
||||
if auth_status == UserAuthStatus.INVALID_CREDENTIALS:
|
||||
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
|
||||
|
||||
return Response(
|
||||
|
@ -154,7 +154,7 @@ def authenticate(
|
|||
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
|
||||
def get_auth_status(req, skip_auth_methods=None) -> UserAuthStatus:
|
||||
"""
|
||||
Check against the available authentication methods (except those listed in
|
||||
``skip_auth_methods``) if the user is properly authenticated.
|
||||
|
@ -168,29 +168,33 @@ def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
|
|||
if n_users > 0 and 'http' not in skip_methods:
|
||||
http_auth_ok = authenticate_user_pass(req)
|
||||
if http_auth_ok:
|
||||
return AuthStatus.OK
|
||||
return UserAuthStatus.OK
|
||||
|
||||
# Token-based authentication
|
||||
token_auth_ok = True
|
||||
if 'token' not in skip_methods:
|
||||
token_auth_ok = authenticate_token(req)
|
||||
if token_auth_ok:
|
||||
return AuthStatus.OK
|
||||
return UserAuthStatus.OK
|
||||
|
||||
# Session token based authentication
|
||||
session_auth_ok = True
|
||||
if n_users > 0 and 'session' not in skip_methods:
|
||||
return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED
|
||||
return (
|
||||
UserAuthStatus.OK
|
||||
if authenticate_session(req)
|
||||
else UserAuthStatus.INVALID_CREDENTIALS
|
||||
)
|
||||
|
||||
# At least a user should be created before accessing an authenticated resource
|
||||
if n_users == 0 and 'session' not in skip_methods:
|
||||
return AuthStatus.NO_USERS
|
||||
return UserAuthStatus.REGISTRATION_REQUIRED
|
||||
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
('http' not in skip_methods and http_auth_ok)
|
||||
or ('token' not in skip_methods and token_auth_ok)
|
||||
or ('session' not in skip_methods and session_auth_ok)
|
||||
):
|
||||
return AuthStatus.OK
|
||||
return UserAuthStatus.OK
|
||||
|
||||
return AuthStatus.UNAUTHORIZED
|
||||
return UserAuthStatus.INVALID_CREDENTIALS
|
||||
|
|
|
@ -1,21 +1,67 @@
|
|||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
from flask import jsonify
|
||||
|
||||
StatusValue = namedtuple('StatusValue', ['code', 'message'])
|
||||
from platypush.user import AuthenticationStatus
|
||||
|
||||
StatusValue = namedtuple('StatusValue', ['code', 'error', 'message'])
|
||||
|
||||
|
||||
class AuthStatus(Enum):
|
||||
class UserAuthStatus(Enum):
|
||||
"""
|
||||
Models the status of the authentication.
|
||||
"""
|
||||
|
||||
OK = StatusValue(200, 'OK')
|
||||
UNAUTHORIZED = StatusValue(401, 'Unauthorized')
|
||||
NO_USERS = StatusValue(412, 'Please create a user first')
|
||||
OK = StatusValue(200, AuthenticationStatus.OK, 'OK')
|
||||
INVALID_AUTH_TYPE = StatusValue(
|
||||
400, AuthenticationStatus.INVALID_AUTH_TYPE, 'Invalid authentication type'
|
||||
)
|
||||
INVALID_CREDENTIALS = StatusValue(
|
||||
401, AuthenticationStatus.INVALID_CREDENTIALS, 'Invalid credentials'
|
||||
)
|
||||
INVALID_JWT_TOKEN = StatusValue(
|
||||
401, AuthenticationStatus.INVALID_JWT_TOKEN, 'Invalid JWT token'
|
||||
)
|
||||
INVALID_OTP_CODE = StatusValue(
|
||||
401, AuthenticationStatus.INVALID_OTP_CODE, 'Invalid OTP code'
|
||||
)
|
||||
INVALID_METHOD = StatusValue(
|
||||
405, AuthenticationStatus.INVALID_METHOD, 'Invalid method'
|
||||
)
|
||||
MISSING_OTP_CODE = StatusValue(
|
||||
401, AuthenticationStatus.MISSING_OTP_CODE, 'Missing OTP code'
|
||||
)
|
||||
MISSING_PASSWORD = StatusValue(
|
||||
400, AuthenticationStatus.MISSING_PASSWORD, 'Missing password'
|
||||
)
|
||||
MISSING_USERNAME = StatusValue(
|
||||
400, AuthenticationStatus.MISSING_USERNAME, 'Missing username'
|
||||
)
|
||||
PASSWORD_MISMATCH = StatusValue(
|
||||
400, AuthenticationStatus.PASSWORD_MISMATCH, 'Password mismatch'
|
||||
)
|
||||
REGISTRATION_DISABLED = StatusValue(
|
||||
401, AuthenticationStatus.REGISTRATION_DISABLED, 'Registrations are disabled'
|
||||
)
|
||||
REGISTRATION_REQUIRED = StatusValue(
|
||||
412, AuthenticationStatus.REGISTRATION_REQUIRED, 'Please create a user first'
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'code': self.value[0],
|
||||
'message': self.value[1],
|
||||
'error': self.value[1].name,
|
||||
'message': self.value[2],
|
||||
}
|
||||
|
||||
def to_response(self):
|
||||
return jsonify(self.to_dict()), self.value[0]
|
||||
|
||||
@staticmethod
|
||||
def by_status(status: AuthenticationStatus):
|
||||
for auth_status in UserAuthStatus:
|
||||
if auth_status.value[1] == status:
|
||||
return auth_status
|
||||
|
||||
return None
|
||||
|
|
|
@ -5,7 +5,7 @@ from threading import Thread
|
|||
from tornado.ioloop import IOLoop
|
||||
from tornado.websocket import WebSocketHandler
|
||||
|
||||
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
||||
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
|
||||
|
||||
from ..mixins import MessageType, PubSubMixin
|
||||
|
||||
|
@ -25,7 +25,7 @@ class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
|
|||
|
||||
def open(self, *_, **__):
|
||||
auth_status = get_auth_status(self.request)
|
||||
if auth_status != AuthStatus.OK:
|
||||
if auth_status != UserAuthStatus.OK:
|
||||
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
|
||||
return
|
||||
|
||||
|
|
|
@ -1,20 +1,30 @@
|
|||
<template>
|
||||
<Events ref="events" v-if="hasWebsocket" />
|
||||
<Notifications ref="notifications" />
|
||||
<VoiceAssistant ref="voice-assistant" v-if="hasAssistant" />
|
||||
<Pushbullet ref="pushbullet" v-if="hasPushbullet" />
|
||||
<Ntfy ref="ntfy" v-if="hasNtfy" />
|
||||
<ConfirmDialog ref="pwaDialog" @input="installPWA">
|
||||
Would you like to install this application locally?
|
||||
</ConfirmDialog>
|
||||
<div id="error" v-if="initError">
|
||||
<h1>Initialization error</h1>
|
||||
<p>{{ initError }}</p>
|
||||
</div>
|
||||
|
||||
<DropdownContainer />
|
||||
<router-view />
|
||||
<Loading v-else-if="!initialized" />
|
||||
|
||||
<div id="app-container" v-else>
|
||||
<Events ref="events" v-if="hasWebsocket" />
|
||||
<Notifications ref="notifications" />
|
||||
<VoiceAssistant ref="voice-assistant" v-if="hasAssistant" />
|
||||
<Pushbullet ref="pushbullet" v-if="hasPushbullet" />
|
||||
<Ntfy ref="ntfy" v-if="hasNtfy" />
|
||||
<ConfirmDialog ref="pwaDialog" @input="installPWA">
|
||||
Would you like to install this application locally?
|
||||
</ConfirmDialog>
|
||||
|
||||
<DropdownContainer />
|
||||
<router-view />
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import ConfirmDialog from "@/components/elements/ConfirmDialog";
|
||||
import DropdownContainer from "@/components/elements/DropdownContainer";
|
||||
import Loading from "@/components/Loading";
|
||||
import Notifications from "@/components/Notifications";
|
||||
import Utils from "@/Utils";
|
||||
import Events from "@/Events";
|
||||
|
@ -29,6 +39,7 @@ export default {
|
|||
ConfirmDialog,
|
||||
DropdownContainer,
|
||||
Events,
|
||||
Loading,
|
||||
Notifications,
|
||||
Ntfy,
|
||||
Pushbullet,
|
||||
|
@ -41,6 +52,8 @@ export default {
|
|||
userAuthenticated: false,
|
||||
connected: false,
|
||||
pwaInstallEvent: null,
|
||||
initialized: false,
|
||||
initError: null,
|
||||
}
|
||||
},
|
||||
|
||||
|
@ -84,8 +97,18 @@ export default {
|
|||
}
|
||||
},
|
||||
|
||||
created() {
|
||||
this.initConfig()
|
||||
async created() {
|
||||
try {
|
||||
await this.initConfig()
|
||||
} catch (e) {
|
||||
const code = e?.response?.data?.code
|
||||
if (![401, 403, 412].includes(code)) {
|
||||
this.initError = e
|
||||
console.error('Initialization error', e)
|
||||
}
|
||||
} finally {
|
||||
this.initialized = true
|
||||
}
|
||||
},
|
||||
|
||||
beforeMount() {
|
||||
|
@ -125,6 +148,11 @@ html, body {
|
|||
overflow: auto;
|
||||
}
|
||||
|
||||
#app-container {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#app {
|
||||
font-family: BlinkMacSystemFont,-apple-system,Avenir,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,Helvetica,Verdana,Arial,sans-serif;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
|
|
|
@ -8,6 +8,7 @@ $default-bg-6: #e4eae8 !default;
|
|||
$default-bg-7: #e4e4e4 !default;
|
||||
$ok-fg: #17ad17 !default;
|
||||
$error-fg: #ad1717 !default;
|
||||
$error-bg: #ffaaa2 !default;
|
||||
$tile-bg: linear-gradient(90deg, rgba(9,174,128,1) 0%, rgba(71,226,179,1) 120%);
|
||||
$tile-fg: white;
|
||||
$tile-hover-bg: linear-gradient(90deg, rgba(41,216,159,1) 0%, rgba(9,188,138,1) 70%);
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
<template>
|
||||
<div class="login-container">
|
||||
<form class="login" method="POST">
|
||||
<Loading v-if="!initialized" />
|
||||
|
||||
<div class="login-container" v-else>
|
||||
<form class="login" method="POST" @submit="submitForm" v-if="!isAuthenticated">
|
||||
<div class="header">
|
||||
<span class="logo">
|
||||
<img src="/logo.svg" alt="logo" />
|
||||
|
@ -10,24 +12,30 @@
|
|||
|
||||
<div class="row">
|
||||
<label>
|
||||
<input type="text" name="username" placeholder="Username">
|
||||
<input type="text" name="username" :disabled="authenticating" placeholder="Username" ref="username">
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<label>
|
||||
<input type="password" name="password" placeholder="Password">
|
||||
<input type="password" name="password" :disabled="authenticating" placeholder="Password">
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="row" v-if="_register">
|
||||
<div class="row" v-if="register">
|
||||
<label>
|
||||
<input type="password" name="confirm_password" placeholder="Confirm password">
|
||||
<input type="password" name="confirm_password" :disabled="authenticating" placeholder="Confirm password">
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="row buttons">
|
||||
<input type="submit" class="btn btn-primary" :value="_register ? 'Register' : 'Login'">
|
||||
<button type="submit"
|
||||
class="btn btn-primary"
|
||||
:class="{loading: authenticating}"
|
||||
:disabled="authenticating">
|
||||
<Loading v-if="authenticating" />
|
||||
{{ register ? 'Register' : 'Login' }}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="row pull-right">
|
||||
|
@ -36,16 +44,26 @@
|
|||
Keep me logged in on this device
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div class="auth-error" v-if="authError">
|
||||
{{ authError }}
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import Loading from "@/components/Loading";
|
||||
import Utils from "@/Utils";
|
||||
import axios from 'axios'
|
||||
|
||||
export default {
|
||||
name: "Login",
|
||||
mixins: [Utils],
|
||||
components: {
|
||||
Loading,
|
||||
},
|
||||
|
||||
props: {
|
||||
// Set to true for a registration form, false for a login form
|
||||
register: {
|
||||
|
@ -56,10 +74,86 @@ export default {
|
|||
},
|
||||
|
||||
computed: {
|
||||
_register() {
|
||||
return this.parseBoolean(this.register)
|
||||
redirect() {
|
||||
return this.$route.query.redirect?.length ? this.$route.query.redirect : '/'
|
||||
},
|
||||
}
|
||||
},
|
||||
|
||||
data() {
|
||||
return {
|
||||
authError: null,
|
||||
authenticating: false,
|
||||
isAuthenticated: false,
|
||||
initialized: false,
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
async submitForm(e) {
|
||||
e.preventDefault();
|
||||
const form = e.target
|
||||
const data = new FormData(form)
|
||||
const url = `/auth?type=${this.register ? 'register' : 'login'}`
|
||||
|
||||
if (this.register && data.get('password') !== data.get('confirm_password')) {
|
||||
this.authError = "Passwords don't match"
|
||||
return
|
||||
}
|
||||
|
||||
this.authError = null
|
||||
|
||||
try {
|
||||
const authStatus = await axios.post(url, data)
|
||||
const sessionToken = authStatus?.data?.session_token
|
||||
if (sessionToken) {
|
||||
const expiresAt = authStatus.expires_at ? Date.parse(authStatus.expires_at) : null
|
||||
this.isAuthenticated = true
|
||||
this.setCookie('session_token', sessionToken, {
|
||||
expires: expiresAt,
|
||||
})
|
||||
window.location.href = authStatus.redirect || this.redirect
|
||||
} else {
|
||||
this.authError = "Invalid credentials"
|
||||
}
|
||||
} catch (e) {
|
||||
this.authError = e.response.data.message || e.response.data.error
|
||||
|
||||
if (e.response?.status === 401) {
|
||||
this.authError = this.authError || "Invalid credentials"
|
||||
} else {
|
||||
this.authError = this.authError || "An error occurred while processing the request"
|
||||
if (e.response)
|
||||
console.error(e.response.status, e.response.data)
|
||||
else
|
||||
console.error(e)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async checkAuth() {
|
||||
try {
|
||||
const authStatus = await axios.get('/auth')
|
||||
if (authStatus.data.session_token) {
|
||||
this.isAuthenticated = true
|
||||
window.location.href = authStatus.redirect || this.redirect
|
||||
}
|
||||
} catch (e) {
|
||||
this.isAuthenticated = false
|
||||
} finally {
|
||||
this.initialized = true
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
async created() {
|
||||
await this.checkAuth()
|
||||
},
|
||||
|
||||
async mounted() {
|
||||
this.$nextTick(() => {
|
||||
this.$refs.username?.focus()
|
||||
})
|
||||
},
|
||||
}
|
||||
</script>
|
||||
|
||||
|
@ -116,7 +210,7 @@ form {
|
|||
width: 100%;
|
||||
}
|
||||
|
||||
input[type=submit],
|
||||
[type=submit],
|
||||
input[type=password] {
|
||||
border-radius: 1em;
|
||||
}
|
||||
|
@ -133,10 +227,33 @@ form {
|
|||
.buttons {
|
||||
text-align: center;
|
||||
|
||||
input[type=submit] {
|
||||
[type=submit] {
|
||||
position: relative;
|
||||
width: 6em;
|
||||
height: 2.5em;
|
||||
padding: .5em .75em;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
|
||||
&.loading {
|
||||
background: none;
|
||||
border: none;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.auth-error {
|
||||
background: $error-bg;
|
||||
display: flex;
|
||||
margin: 1em 0 -2em 0;
|
||||
padding: .5em;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
border: $notification-error-border;
|
||||
border-radius: 1em;
|
||||
}
|
||||
}
|
||||
|
||||
a {
|
||||
|
|
|
@ -34,16 +34,17 @@ module.exports = {
|
|||
|
||||
devServer: {
|
||||
proxy: {
|
||||
'^/auth': httpProxy,
|
||||
'^/camera/': httpProxy,
|
||||
'^/execute': httpProxy,
|
||||
'^/file': httpProxy,
|
||||
'^/logo.svg': httpProxy,
|
||||
'^/logout': httpProxy,
|
||||
'^/media/': httpProxy,
|
||||
'^/sound/': httpProxy,
|
||||
'^/ws/events': wsProxy,
|
||||
'^/ws/requests': wsProxy,
|
||||
'^/ws/shell': wsProxy,
|
||||
'^/execute': httpProxy,
|
||||
'^/file': httpProxy,
|
||||
'^/auth': httpProxy,
|
||||
'^/camera/': httpProxy,
|
||||
'^/sound/': httpProxy,
|
||||
'^/media/': httpProxy,
|
||||
'^/logo.svg': httpProxy,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||||
|
||||
from platypush.plugins import Plugin, action
|
||||
from platypush.user import UserManager
|
||||
|
@ -66,15 +66,46 @@ class UserPlugin(Plugin):
|
|||
}
|
||||
|
||||
@action
|
||||
def authenticate_user(self, username, password):
|
||||
def authenticate_user(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
code: Optional[str] = None,
|
||||
return_details: bool = False,
|
||||
) -> Union[bool, Tuple[bool, str]]:
|
||||
"""
|
||||
Authenticate a user.
|
||||
|
||||
:return: True if the provided username and password are correct, False
|
||||
otherwise.
|
||||
"""
|
||||
:param username: Username.
|
||||
:param password: Password.
|
||||
:param code: Optional 2FA code, if 2FA is enabled for the user.
|
||||
:param return_error_details: If True then return the error details in
|
||||
case of authentication failure.
|
||||
:return: If ``return_details`` is False (default), the action returns
|
||||
True if the provided credentials are valid, False otherwise.
|
||||
If ``return_details`` is True then the action returns a tuple
|
||||
(authenticated, error_details) where ``authenticated`` is True if
|
||||
the provided credentials are valid, False otherwise, and
|
||||
``error_details`` is a string containing the error details in case
|
||||
of authentication failure. Supported error details are:
|
||||
|
||||
return bool(self.user_manager.authenticate_user(username, password))
|
||||
- ``invalid_credentials``: Invalid username or password.
|
||||
- ``invalid_otp_code``: Invalid 2FA code.
|
||||
- ``missing_otp_code``: Username/password are correct, but a 2FA
|
||||
code is required for the user.
|
||||
|
||||
"""
|
||||
response = self.user_manager.authenticate_user(
|
||||
username, password, code=code, return_error=return_details
|
||||
)
|
||||
|
||||
if return_details:
|
||||
assert (
|
||||
isinstance(response, tuple) and len(response) == 2
|
||||
), 'Invalid response from authenticate_user'
|
||||
return response[0], response[1].value
|
||||
|
||||
return response
|
||||
|
||||
@action
|
||||
def update_password(self, username, old_password, new_password):
|
||||
|
@ -111,7 +142,7 @@ class UserPlugin(Plugin):
|
|||
return None, "No such user: {}".format(username)
|
||||
|
||||
@action
|
||||
def create_session(self, username, password, expires_at=None):
|
||||
def create_session(self, username, password, code=None, expires_at=None):
|
||||
"""
|
||||
Create a user session.
|
||||
|
||||
|
@ -130,7 +161,7 @@ class UserPlugin(Plugin):
|
|||
"""
|
||||
|
||||
session = self.user_manager.create_user_session(
|
||||
username=username, password=password, expires_at=expires_at
|
||||
username=username, password=password, code=code, expires_at=expires_at
|
||||
)
|
||||
|
||||
if not session:
|
||||
|
@ -140,9 +171,9 @@ class UserPlugin(Plugin):
|
|||
'session_token': session.session_token,
|
||||
'user_id': session.user_id,
|
||||
'created_at': session.created_at.isoformat(),
|
||||
'expires_at': session.expires_at.isoformat()
|
||||
if session.expires_at
|
||||
else None,
|
||||
'expires_at': (
|
||||
session.expires_at.isoformat() if session.expires_at else None # type: ignore
|
||||
),
|
||||
}
|
||||
|
||||
@action
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import base64
|
||||
import datetime
|
||||
import enum
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Dict
|
||||
from typing import List, Optional, Dict, Tuple, Union
|
||||
|
||||
import rsa
|
||||
|
||||
|
@ -13,12 +14,32 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
|||
from sqlalchemy.orm import make_transient
|
||||
|
||||
from platypush.common.db import Base
|
||||
from platypush.config import Config
|
||||
from platypush.context import get_plugin
|
||||
from platypush.exceptions.user import (
|
||||
InvalidJWTTokenException,
|
||||
InvalidCredentialsException,
|
||||
)
|
||||
from platypush.utils import get_or_generate_jwt_rsa_key_pair, utcnow
|
||||
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
|
||||
|
||||
|
||||
class AuthenticationStatus(enum.Enum):
|
||||
"""
|
||||
Enum for authentication errors.
|
||||
"""
|
||||
|
||||
OK = ''
|
||||
INVALID_AUTH_TYPE = 'invalid_auth_type'
|
||||
INVALID_CREDENTIALS = 'invalid_credentials'
|
||||
INVALID_METHOD = 'invalid_method'
|
||||
INVALID_JWT_TOKEN = 'invalid_jwt_token'
|
||||
INVALID_OTP_CODE = 'invalid_otp_code'
|
||||
MISSING_OTP_CODE = 'missing_otp_code'
|
||||
MISSING_PASSWORD = 'missing_password'
|
||||
MISSING_USERNAME = 'missing_username'
|
||||
PASSWORD_MISMATCH = 'password_mismatch'
|
||||
REGISTRATION_DISABLED = 'registration_disabled'
|
||||
REGISTRATION_REQUIRED = 'registration_required'
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
@ -26,6 +47,14 @@ class UserManager:
|
|||
Main class for managing platform users
|
||||
"""
|
||||
|
||||
_otp_workdir = os.path.join(Config.get_workdir(), 'otp')
|
||||
_otp_keyfile = os.path.join(_otp_workdir, 'key')
|
||||
_otp_keyfile_pub = f'{_otp_keyfile}.pub'
|
||||
|
||||
_jwt_workdir = os.path.join(Config.get_workdir(), 'jwt')
|
||||
_jwt_keyfile = os.path.join(_jwt_workdir, 'id_rsa')
|
||||
_jwt_keyfile_pub = f'{_jwt_keyfile}.pub'
|
||||
|
||||
def __init__(self):
|
||||
db_plugin = get_plugin('db')
|
||||
assert db_plugin, 'Database plugin not configured'
|
||||
|
@ -41,6 +70,44 @@ class UserManager:
|
|||
def _get_session(self, *args, **kwargs):
|
||||
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _get_jwt_rsa_key_pair(cls):
|
||||
"""
|
||||
Get or generate the JWT RSA key pair.
|
||||
"""
|
||||
return get_or_generate_stored_rsa_key_pair(cls._jwt_keyfile, size=2048)
|
||||
|
||||
@classmethod
|
||||
def _get_or_generate_otp_rsa_key_pair(cls):
|
||||
"""
|
||||
Get or generate the OTP RSA key pair.
|
||||
"""
|
||||
return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=4096)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt(data: Union[str, bytes, dict, list, tuple], key: rsa.PublicKey) -> str:
|
||||
"""
|
||||
Encrypt the data using the given RSA public key.
|
||||
"""
|
||||
if isinstance(data, tuple):
|
||||
data = list(data)
|
||||
if isinstance(data, (dict, list)):
|
||||
data = json.dumps(data, sort_keys=True, indent=None)
|
||||
if isinstance(data, str):
|
||||
data = data.encode('ascii')
|
||||
|
||||
return base64.b64encode(rsa.encrypt(data, key)).decode()
|
||||
|
||||
@staticmethod
|
||||
def _decrypt(data: Union[str, bytes], key: rsa.PrivateKey) -> str:
|
||||
"""
|
||||
Decrypt the data using the given RSA private key.
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('ascii')
|
||||
|
||||
return rsa.decrypt(base64.b64decode(data), key).decode()
|
||||
|
||||
def get_user(self, username):
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
|
@ -88,9 +155,9 @@ class UserManager:
|
|||
|
||||
return self._mask_password(user)
|
||||
|
||||
def update_password(self, username, old_password, new_password):
|
||||
def update_password(self, username, old_password, new_password, code=None):
|
||||
with self._get_session(locked=True) as session:
|
||||
if not self._authenticate_user(session, username, old_password):
|
||||
if not self._authenticate_user(session, username, old_password, code=code):
|
||||
return False
|
||||
|
||||
user = self._get_user(session, username)
|
||||
|
@ -103,12 +170,22 @@ class UserManager:
|
|||
session.commit()
|
||||
return True
|
||||
|
||||
def authenticate_user(self, username, password):
|
||||
def authenticate_user(self, username, password, code=None, return_error=False):
|
||||
with self._get_session() as session:
|
||||
return self._authenticate_user(session, username, password)
|
||||
return self._authenticate_user(
|
||||
session, username, password, code=code, return_error=return_error
|
||||
)
|
||||
|
||||
def authenticate_user_session(self, session_token):
|
||||
def authenticate_user_session(self, session_token, with_error=False):
|
||||
with self._get_session() as session:
|
||||
users_count = session.query(User).count()
|
||||
if not users_count:
|
||||
return (
|
||||
(None, None, AuthenticationStatus.REGISTRATION_REQUIRED)
|
||||
if with_error
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
user_session = (
|
||||
session.query(UserSession)
|
||||
.filter_by(session_token=session_token)
|
||||
|
@ -122,10 +199,21 @@ class UserManager:
|
|||
)
|
||||
|
||||
if not user_session or (expires_at and expires_at < utcnow()):
|
||||
return None, None
|
||||
return (
|
||||
(None, None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
if with_error
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||
return self._mask_password(user), user_session
|
||||
return (
|
||||
(self._mask_password(user), user_session, AuthenticationStatus.OK)
|
||||
if with_error
|
||||
else (
|
||||
self._mask_password(user),
|
||||
user_session,
|
||||
)
|
||||
)
|
||||
|
||||
def delete_user(self, username):
|
||||
with self._get_session(locked=True) as session:
|
||||
|
@ -158,11 +246,25 @@ class UserManager:
|
|||
session.commit()
|
||||
return True
|
||||
|
||||
def create_user_session(self, username, password, expires_at=None):
|
||||
def create_user_session(
|
||||
self,
|
||||
username,
|
||||
password,
|
||||
code=None,
|
||||
expires_at=None,
|
||||
error_on_invalid_credentials=False,
|
||||
):
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._authenticate_user(session, username, password)
|
||||
user, status = self._authenticate_user( # type: ignore
|
||||
session,
|
||||
username,
|
||||
password,
|
||||
code=code,
|
||||
return_error=error_on_invalid_credentials,
|
||||
)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
return None if not error_on_invalid_credentials else (None, status)
|
||||
|
||||
if expires_at:
|
||||
if isinstance(expires_at, (int, float)):
|
||||
|
@ -180,7 +282,53 @@ class UserManager:
|
|||
|
||||
session.add(user_session)
|
||||
session.commit()
|
||||
return user_session
|
||||
return user_session, (
|
||||
AuthenticationStatus.OK if not error_on_invalid_credentials else status
|
||||
)
|
||||
|
||||
def create_otp_secret(
|
||||
self, username: str, expires_at: Optional[datetime.datetime] = None
|
||||
):
|
||||
pubkey, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
|
||||
# Generate a new OTP secret and encrypt it with the OTP RSA key pair
|
||||
otp_secret = "".join(
|
||||
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(32)
|
||||
)
|
||||
|
||||
encrypted_secret = self._encrypt(otp_secret, pubkey)
|
||||
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
assert user, f'No such user: {username}'
|
||||
|
||||
# Create a new OTP secret
|
||||
user_otp = UserOtp(
|
||||
user_id=user.user_id,
|
||||
otp_secret=encrypted_secret,
|
||||
created_at=utcnow(),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Remove any existing OTP secret and replace it with the new one
|
||||
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
|
||||
session.add(user_otp)
|
||||
session.commit()
|
||||
|
||||
return user_otp
|
||||
|
||||
def get_otp_secret(self, username: str) -> Optional[str]:
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user_otp = session.query(UserOtp).filter_by(user_id=user.user_id).first()
|
||||
if not user_otp:
|
||||
return None
|
||||
|
||||
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||
return self._decrypt(user_otp.otp_secret, priv_key)
|
||||
|
||||
@staticmethod
|
||||
def _get_user(session, username):
|
||||
|
@ -268,20 +416,17 @@ class UserManager:
|
|||
if not user:
|
||||
raise InvalidCredentialsException()
|
||||
|
||||
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
|
||||
payload = json.dumps(
|
||||
pub_key, _ = self._get_jwt_rsa_key_pair()
|
||||
return self._encrypt(
|
||||
{
|
||||
'username': username,
|
||||
'password': password,
|
||||
'created_at': datetime.datetime.now().timestamp(),
|
||||
'expires_at': expires_at.timestamp() if expires_at else None,
|
||||
},
|
||||
sort_keys=True,
|
||||
indent=None,
|
||||
pub_key,
|
||||
)
|
||||
|
||||
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
|
||||
|
||||
def validate_jwt_token(self, token: str) -> Dict[str, str]:
|
||||
"""
|
||||
Validate a JWT token.
|
||||
|
@ -299,14 +444,10 @@ class UserManager:
|
|||
|
||||
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
|
||||
"""
|
||||
_, priv_key = get_or_generate_jwt_rsa_key_pair()
|
||||
_, priv_key = self._get_jwt_rsa_key_pair()
|
||||
|
||||
try:
|
||||
payload = json.loads(
|
||||
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
|
||||
'ascii'
|
||||
)
|
||||
)
|
||||
payload = json.loads(self._decrypt(token, priv_key))
|
||||
except (TypeError, ValueError) as e:
|
||||
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
|
||||
|
||||
|
@ -323,23 +464,160 @@ class UserManager:
|
|||
|
||||
return payload
|
||||
|
||||
def _authenticate_user(self, session, username, password):
|
||||
def _authenticate_user(
|
||||
self,
|
||||
session,
|
||||
username: str,
|
||||
password: str,
|
||||
code: Optional[str] = None,
|
||||
return_error: bool = False,
|
||||
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
|
||||
"""
|
||||
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
|
||||
:return: :class:`platypush.user.User` instance if the user exists and
|
||||
the password is valid, ``None`` otherwise.
|
||||
"""
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# The user does not exist
|
||||
if not user:
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
)
|
||||
|
||||
# The password is not correct
|
||||
if not self._check_password(
|
||||
password,
|
||||
user.password,
|
||||
bytes.fromhex(user.password_salt) if user.password_salt else None,
|
||||
user.hmac_iterations,
|
||||
):
|
||||
return None
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
)
|
||||
|
||||
return user
|
||||
otp_secret = self.get_otp_secret(username)
|
||||
|
||||
# The user doesn't have 2FA enabled and the password is correct:
|
||||
# authentication successful
|
||||
if not otp_secret:
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
|
||||
# The user has 2FA enabled but the code is missing
|
||||
if not code:
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
else (None, AuthenticationStatus.MISSING_OTP_CODE)
|
||||
)
|
||||
|
||||
if self.validate_otp_code(username, code):
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
|
||||
if not self.validate_backup_code(username, code):
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
else (None, AuthenticationStatus.INVALID_OTP_CODE)
|
||||
)
|
||||
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
|
||||
def refresh_user_backup_codes(self, username: str):
|
||||
"""
|
||||
Refresh the backup codes for a user with 2FA enabled.
|
||||
"""
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
|
||||
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
|
||||
for _ in range(10):
|
||||
backup_code = "".join(
|
||||
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16)
|
||||
)
|
||||
|
||||
user_backup_code = UserBackupCode(
|
||||
user_id=user.user_id,
|
||||
code=self._encrypt(backup_code, pub_key),
|
||||
created_at=utcnow(),
|
||||
expires_at=utcnow() + datetime.timedelta(days=30),
|
||||
)
|
||||
|
||||
session.add(user_backup_code)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
def get_user_backup_codes(self, username: str) -> List['UserBackupCode']:
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return []
|
||||
|
||||
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||
codes = session.query(UserBackupCode).filter_by(user_id=user.user_id).all()
|
||||
|
||||
for code in codes:
|
||||
code.code = self._decrypt(code.code, priv_key)
|
||||
|
||||
return codes
|
||||
|
||||
def validate_backup_code(self, username: str, code: str) -> bool:
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
user_backup_code = (
|
||||
session.query(UserBackupCode)
|
||||
.filter_by(user_id=user.user_id, code=self._encrypt(code, pub_key))
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_backup_code:
|
||||
return False
|
||||
|
||||
session.delete(user_backup_code)
|
||||
session.commit()
|
||||
|
||||
return True
|
||||
|
||||
def validate_otp_code(self, username: str, code: str) -> bool:
|
||||
otp = get_plugin('otp')
|
||||
assert otp
|
||||
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
otp_secret = self.get_otp_secret(username)
|
||||
if not otp_secret:
|
||||
return False
|
||||
|
||||
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||
otp_secret = self._decrypt(otp_secret, priv_key)
|
||||
|
||||
return otp.verify_time_otp(otp=code, secret=otp_secret)
|
||||
|
||||
def disable_mfa(self, username: str):
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
|
||||
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
|
||||
class User(Base):
|
||||
|
@ -370,4 +648,31 @@ class UserSession(Base):
|
|||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserOtp(Base):
|
||||
"""
|
||||
Models the UserOtp table, which contains the OTP secrets for each user.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_otp'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
otp_secret = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserBackupCode(Base):
|
||||
"""
|
||||
Models the UserBackupCode table, which contains the backup codes for each
|
||||
user with 2FA enabled.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_backup_code'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
code = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -532,7 +532,12 @@ def generate_rsa_key_pair(
|
|||
private_key_str = priv_key.save_pkcs1('PEM').decode()
|
||||
|
||||
if key_file:
|
||||
pathlib.Path(os.path.dirname(os.path.expanduser(key_file))).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
logger.info('Saving private key to %s', key_file)
|
||||
|
||||
with open(os.path.expanduser(key_file), 'w') as f1, open(
|
||||
os.path.expanduser(key_file) + '.pub', 'w'
|
||||
) as f2:
|
||||
|
@ -543,14 +548,20 @@ def generate_rsa_key_pair(
|
|||
return pub_key, priv_key
|
||||
|
||||
|
||||
def get_or_generate_jwt_rsa_key_pair():
|
||||
def get_or_generate_stored_rsa_key_pair(
|
||||
keyfile: str, size: int = 2048
|
||||
) -> Tuple[PublicKey, PrivateKey]:
|
||||
"""
|
||||
Get or generate a JWT RSA key pair.
|
||||
"""
|
||||
from platypush.config import Config
|
||||
Get or generate an RSA key pair and store it in the given key file.
|
||||
|
||||
key_dir = os.path.join(Config.get_workdir(), 'jwt')
|
||||
priv_key_file = os.path.join(key_dir, 'id_rsa')
|
||||
The private key will be stored in the given file, while the public key will
|
||||
be stored in ``<keyfile>.pub``.
|
||||
|
||||
:param keyfile: Path to the key file.
|
||||
:param size: Key size in bits (default: 2048).
|
||||
"""
|
||||
keydir = os.path.dirname(os.path.expanduser(keyfile))
|
||||
priv_key_file = os.path.join(keydir, os.path.basename(keyfile))
|
||||
pub_key_file = priv_key_file + '.pub'
|
||||
|
||||
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
|
||||
|
@ -560,8 +571,15 @@ def get_or_generate_jwt_rsa_key_pair():
|
|||
PrivateKey.load_pkcs1(f2.read().encode()),
|
||||
)
|
||||
|
||||
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||
return generate_rsa_key_pair(priv_key_file, size=2048)
|
||||
pub_key, priv_key = generate_rsa_key_pair(priv_key_file, size=size)
|
||||
pathlib.Path(keydir).mkdir(parents=True, exist_ok=True, mode=0o755)
|
||||
|
||||
with open(pub_key_file, 'w') as f1, open(priv_key_file, 'w') as f2:
|
||||
f1.write(pub_key.save_pkcs1('PEM').decode())
|
||||
f2.write(priv_key.save_pkcs1('PEM').decode())
|
||||
os.chmod(priv_key_file, 0o600)
|
||||
|
||||
return pub_key, priv_key
|
||||
|
||||
|
||||
def get_enabled_plugins() -> dict:
|
||||
|
|
Loading…
Reference in a new issue