[core] Refactored Web login/registration layer.

Instead of having a single Flask-provided endpoint, the UI should
initialize its own Vue component and manage the authentication
asynchronously over API.

This is especially a requirement for the implementation of 2FA.

The following routes have also been merged/refactored:

- `POST /register` -> `POST /auth?type=register`
- `POST /login` -> `POST /auth?type=login`
- `POST /auth` -> `POST /auth?type=jwt`
This commit is contained in:
Fabio Manganiello 2024-07-23 02:05:29 +02:00
parent 8904e40f9f
commit ee27b2c4c6
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
14 changed files with 839 additions and 262 deletions

View file

@ -4,8 +4,10 @@ import logging
from flask import Blueprint, request, abort, jsonify
from platypush.backend.http.app.utils.auth import UserAuthStatus
from platypush.exceptions.user import UserException
from platypush.user import UserManager
from platypush.utils import utcnow
auth = Blueprint('auth', __name__)
log = logging.getLogger(__name__)
@ -16,39 +18,24 @@ __routes__ = [
]
@auth.route('/auth', methods=['POST'])
def auth_endpoint():
"""
Authentication endpoint. It validates the user credentials provided over a JSON payload with the following
structure:
.. code-block:: json
def _dump_session(session, redirect_page='/'):
return jsonify(
{
"username": "USERNAME",
"password": "PASSWORD",
"expiry_days": "The generated token should be valid for these many days"
'status': 'ok',
'user_id': session.user_id,
'session_token': session.session_token,
'expires_at': session.expires_at,
'redirect': redirect_page,
}
)
``expiry_days`` is optional, and if omitted or set to zero the token will be valid indefinitely.
Upon successful validation, a new JWT token will be generated using the service's self-generated RSA key-pair and it
will be returned to the user. The token can then be used to authenticate API calls to ``/execute`` by setting the
``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
:return: Return structure:
.. code-block:: json
{
"token": "<generated token here>"
}
"""
def _jwt_auth():
try:
payload = json.loads(request.get_data(as_text=True))
username, password = payload['username'], payload['password']
except Exception as e:
log.warning('Invalid payload passed to the auth endpoint: ' + str(e))
except Exception:
log.warning('Invalid payload passed to the auth endpoint')
abort(400)
expiry_days = payload.get('expiry_days')
@ -59,8 +46,174 @@ def auth_endpoint():
user_manager = UserManager()
try:
return jsonify({
'token': user_manager.generate_jwt_token(username=username, password=password, expires_at=expires_at),
})
return jsonify(
{
'token': user_manager.generate_jwt_token(
username=username, password=password, expires_at=expires_at
),
}
)
except UserException as e:
abort(401, str(e))
def _session_auth():
user_manager = UserManager()
session_token = request.cookies.get('session_token')
redirect_page = request.args.get('redirect') or '/'
if session_token:
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
if user and session:
return _dump_session(session, redirect_page)
if request.form:
username = request.form.get('username')
password = request.form.get('password')
code = request.form.get('code')
remember = request.form.get('remember')
expires = utcnow() + datetime.timedelta(days=365) if remember else None
session, status = user_manager.create_user_session( # type: ignore
username=username,
password=password,
code=code,
expires_at=expires,
error_on_invalid_credentials=True,
)
if session:
return _dump_session(session, redirect_page)
if status:
return status.to_response() # type: ignore
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
def _register_route():
"""Registration endpoint"""
user_manager = UserManager()
session_token = request.cookies.get('session_token')
redirect_page = request.args.get('redirect') or '/'
if session_token:
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
if user and session:
return _dump_session(session, redirect_page)
if user_manager.get_user_count() > 0:
return UserAuthStatus.REGISTRATION_DISABLED.to_response()
if not request.form:
return UserAuthStatus.MISSING_USERNAME.to_response()
username = request.form.get('username')
password = request.form.get('password')
confirm_password = request.form.get('confirm_password')
remember = request.form.get('remember')
if not username:
return UserAuthStatus.MISSING_USERNAME.to_response()
if not password:
return UserAuthStatus.MISSING_PASSWORD.to_response()
if password != confirm_password:
return UserAuthStatus.PASSWORD_MISMATCH.to_response()
user_manager.create_user(username=username, password=password)
session, status = user_manager.create_user_session( # type: ignore
username=username,
password=password,
expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None),
error_on_invalid_credentials=True,
)
if session:
return _dump_session(session, redirect_page)
if status:
return status.to_response() # type: ignore
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
def _auth_get():
"""
Get the current authentication status of the user session.
"""
user_manager = UserManager()
session_token = request.cookies.get('session_token')
redirect_page = request.args.get('redirect') or '/'
user, session, status = user_manager.authenticate_user_session( # type: ignore
session_token, with_error=True
)
if user and session:
return _dump_session(session, redirect_page)
if status:
return UserAuthStatus.by_status(status).to_response() # type: ignore
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
def _auth_post():
"""
Authenticate the user session.
"""
auth_type = request.args.get('type') or 'jwt'
if auth_type == 'jwt':
return _jwt_auth()
if auth_type == 'register':
return _register_route()
if auth_type == 'login':
return _session_auth()
return UserAuthStatus.INVALID_AUTH_TYPE.to_response()
@auth.route('/auth', methods=['GET', 'POST'])
def auth_endpoint():
"""
Authentication endpoint. It validates the user credentials provided over a
JSON payload with the following structure:
.. code-block:: json
{
"username": "USERNAME",
"password": "PASSWORD",
"code": "2FA_CODE",
"expiry_days": "The generated token should be valid for these many days"
}
``expiry_days`` is optional, and if omitted or set to zero the token will
be valid indefinitely.
Upon successful validation, a new JWT token will be generated using the
service's self-generated RSA key-pair and it will be returned to the user.
The token can then be used to authenticate API calls to ``/execute`` by
setting the ``Authorization: Bearer <TOKEN_HERE>`` header upon HTTP calls.
:return: Return structure:
.. code-block:: json
{
"token": "<generated token here>"
}
"""
if request.method == 'GET':
return _auth_get()
if request.method == 'POST':
return _auth_post()
return UserAuthStatus.INVALID_METHOD.to_response()
# from flask import Blueprint, request, redirect, render_template, make_response, abort
# vim:sw=4:ts=4:et:

View file

@ -1,56 +0,0 @@
import datetime
import re
from flask import Blueprint, request, redirect, render_template, make_response
from platypush.backend.http.app import template_folder
from platypush.backend.http.utils import HttpUtils
from platypush.user import UserManager
from platypush.utils import utcnow
login = Blueprint('login', __name__, template_folder=template_folder)
# Declare routes list
__routes__ = [
login,
]
@login.route('/login', methods=['GET', 'POST'])
def login():
"""Login page"""
user_manager = UserManager()
session_token = request.cookies.get('session_token')
redirect_page = request.args.get('redirect')
if not redirect_page:
redirect_page = request.headers.get('Referer', '/')
if re.search('(^https?://[^/]+)?/login[^?#]?', redirect_page):
# Prevent redirect loop
redirect_page = '/'
if session_token:
user, session = user_manager.authenticate_user_session(session_token)
if user:
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
if request.form:
username = request.form.get('username')
password = request.form.get('password')
remember = request.form.get('remember')
expires = utcnow() + datetime.timedelta(days=365) if remember else None
session = user_manager.create_user_session(
username=username, password=password, expires_at=expires
)
if session:
redirect_target = redirect(redirect_page, 302) # lgtm [py/url-redirection]
response = make_response(redirect_target)
response.set_cookie('session_token', session.session_token, expires=expires)
return response
return render_template('index.html', utils=HttpUtils)
# vim:sw=4:ts=4:et:

View file

@ -1,71 +0,0 @@
import datetime
import re
from flask import Blueprint, request, redirect, render_template, make_response, abort
from platypush.backend.http.app import template_folder
from platypush.backend.http.utils import HttpUtils
from platypush.user import UserManager
from platypush.utils import utcnow
register = Blueprint('register', __name__, template_folder=template_folder)
# Declare routes list
__routes__ = [
register,
]
@register.route('/register', methods=['GET', 'POST'])
def register():
"""Registration page"""
user_manager = UserManager()
redirect_page = request.args.get('redirect')
if not redirect_page:
redirect_page = request.headers.get('Referer', '/')
if re.search('(^https?://[^/]+)?/register[^?#]?', redirect_page):
# Prevent redirect loop
redirect_page = '/'
session_token = request.cookies.get('session_token')
if session_token:
user, session = user_manager.authenticate_user_session(session_token)
if user:
return redirect(redirect_page, 302) # lgtm [py/url-redirection]
if user_manager.get_user_count() > 0:
return redirect(
'/login?redirect=' + redirect_page, 302
) # lgtm [py/url-redirection]
if request.form:
username = request.form.get('username')
password = request.form.get('password')
confirm_password = request.form.get('confirm_password')
remember = request.form.get('remember')
if password == confirm_password:
user_manager.create_user(username=username, password=password)
session = user_manager.create_user_session(
username=username,
password=password,
expires_at=(
utcnow() + datetime.timedelta(days=1) if not remember else None
),
)
if session:
redirect_target = redirect(
redirect_page, 302
) # lgtm [py/url-redirection]
response = make_response(redirect_target)
response.set_cookie('session_token', session.session_token)
return response
else:
abort(400, 'Password mismatch')
return render_template('index.html', utils=HttpUtils)
# vim:sw=4:ts=4:et:

View file

@ -7,7 +7,7 @@ from typing import Optional
from tornado.web import RequestHandler, stream_request_body
from platypush.backend.http.app.utils import logger
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
from ..mixins import PubSubMixin
@ -29,7 +29,7 @@ class StreamingRoute(RequestHandler, PubSubMixin, ABC):
"""
if self.auth_required:
auth_status = get_auth_status(self.request)
if auth_status != AuthStatus.OK:
if auth_status != UserAuthStatus.OK:
self.send_error(auth_status.value.code, error=auth_status.value.message)
return

View file

@ -2,14 +2,14 @@ import base64
from functools import wraps
from typing import Optional
from flask import request, redirect, jsonify
from flask import request, redirect
from flask.wrappers import Response
from platypush.config import Config
from platypush.user import UserManager
from ..logger import logger
from .status import AuthStatus
from .status import UserAuthStatus
user_manager = UserManager()
@ -128,18 +128,18 @@ def authenticate(
skip_auth_methods=skip_auth_methods,
)
if auth_status == AuthStatus.OK:
if auth_status == UserAuthStatus.OK:
return f(*args, **kwargs)
if json:
return jsonify(auth_status.to_dict()), auth_status.value.code
return auth_status.to_response()
if auth_status == AuthStatus.NO_USERS:
if auth_status == UserAuthStatus.REGISTRATION_REQUIRED:
return redirect(
f'/register?redirect={redirect_page or request.url}', 307
)
if auth_status == AuthStatus.UNAUTHORIZED:
if auth_status == UserAuthStatus.INVALID_CREDENTIALS:
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
return Response(
@ -154,7 +154,7 @@ def authenticate(
# pylint: disable=too-many-return-statements
def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
def get_auth_status(req, skip_auth_methods=None) -> UserAuthStatus:
"""
Check against the available authentication methods (except those listed in
``skip_auth_methods``) if the user is properly authenticated.
@ -168,29 +168,33 @@ def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
if n_users > 0 and 'http' not in skip_methods:
http_auth_ok = authenticate_user_pass(req)
if http_auth_ok:
return AuthStatus.OK
return UserAuthStatus.OK
# Token-based authentication
token_auth_ok = True
if 'token' not in skip_methods:
token_auth_ok = authenticate_token(req)
if token_auth_ok:
return AuthStatus.OK
return UserAuthStatus.OK
# Session token based authentication
session_auth_ok = True
if n_users > 0 and 'session' not in skip_methods:
return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED
return (
UserAuthStatus.OK
if authenticate_session(req)
else UserAuthStatus.INVALID_CREDENTIALS
)
# At least a user should be created before accessing an authenticated resource
if n_users == 0 and 'session' not in skip_methods:
return AuthStatus.NO_USERS
return UserAuthStatus.REGISTRATION_REQUIRED
if ( # pylint: disable=too-many-boolean-expressions
('http' not in skip_methods and http_auth_ok)
or ('token' not in skip_methods and token_auth_ok)
or ('session' not in skip_methods and session_auth_ok)
):
return AuthStatus.OK
return UserAuthStatus.OK
return AuthStatus.UNAUTHORIZED
return UserAuthStatus.INVALID_CREDENTIALS

View file

@ -1,21 +1,67 @@
from collections import namedtuple
from enum import Enum
from flask import jsonify
StatusValue = namedtuple('StatusValue', ['code', 'message'])
from platypush.user import AuthenticationStatus
StatusValue = namedtuple('StatusValue', ['code', 'error', 'message'])
class AuthStatus(Enum):
class UserAuthStatus(Enum):
"""
Models the status of the authentication.
"""
OK = StatusValue(200, 'OK')
UNAUTHORIZED = StatusValue(401, 'Unauthorized')
NO_USERS = StatusValue(412, 'Please create a user first')
OK = StatusValue(200, AuthenticationStatus.OK, 'OK')
INVALID_AUTH_TYPE = StatusValue(
400, AuthenticationStatus.INVALID_AUTH_TYPE, 'Invalid authentication type'
)
INVALID_CREDENTIALS = StatusValue(
401, AuthenticationStatus.INVALID_CREDENTIALS, 'Invalid credentials'
)
INVALID_JWT_TOKEN = StatusValue(
401, AuthenticationStatus.INVALID_JWT_TOKEN, 'Invalid JWT token'
)
INVALID_OTP_CODE = StatusValue(
401, AuthenticationStatus.INVALID_OTP_CODE, 'Invalid OTP code'
)
INVALID_METHOD = StatusValue(
405, AuthenticationStatus.INVALID_METHOD, 'Invalid method'
)
MISSING_OTP_CODE = StatusValue(
401, AuthenticationStatus.MISSING_OTP_CODE, 'Missing OTP code'
)
MISSING_PASSWORD = StatusValue(
400, AuthenticationStatus.MISSING_PASSWORD, 'Missing password'
)
MISSING_USERNAME = StatusValue(
400, AuthenticationStatus.MISSING_USERNAME, 'Missing username'
)
PASSWORD_MISMATCH = StatusValue(
400, AuthenticationStatus.PASSWORD_MISMATCH, 'Password mismatch'
)
REGISTRATION_DISABLED = StatusValue(
401, AuthenticationStatus.REGISTRATION_DISABLED, 'Registrations are disabled'
)
REGISTRATION_REQUIRED = StatusValue(
412, AuthenticationStatus.REGISTRATION_REQUIRED, 'Please create a user first'
)
def to_dict(self):
return {
'code': self.value[0],
'message': self.value[1],
'error': self.value[1].name,
'message': self.value[2],
}
def to_response(self):
return jsonify(self.to_dict()), self.value[0]
@staticmethod
def by_status(status: AuthenticationStatus):
for auth_status in UserAuthStatus:
if auth_status.value[1] == status:
return auth_status
return None

View file

@ -5,7 +5,7 @@ from threading import Thread
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
from platypush.backend.http.app.utils.auth import UserAuthStatus, get_auth_status
from ..mixins import MessageType, PubSubMixin
@ -25,7 +25,7 @@ class WSRoute(WebSocketHandler, Thread, PubSubMixin, ABC):
def open(self, *_, **__):
auth_status = get_auth_status(self.request)
if auth_status != AuthStatus.OK:
if auth_status != UserAuthStatus.OK:
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
return

View file

@ -1,4 +1,12 @@
<template>
<div id="error" v-if="initError">
<h1>Initialization error</h1>
<p>{{ initError }}</p>
</div>
<Loading v-else-if="!initialized" />
<div id="app-container" v-else>
<Events ref="events" v-if="hasWebsocket" />
<Notifications ref="notifications" />
<VoiceAssistant ref="voice-assistant" v-if="hasAssistant" />
@ -10,11 +18,13 @@
<DropdownContainer />
<router-view />
</div>
</template>
<script>
import ConfirmDialog from "@/components/elements/ConfirmDialog";
import DropdownContainer from "@/components/elements/DropdownContainer";
import Loading from "@/components/Loading";
import Notifications from "@/components/Notifications";
import Utils from "@/Utils";
import Events from "@/Events";
@ -29,6 +39,7 @@ export default {
ConfirmDialog,
DropdownContainer,
Events,
Loading,
Notifications,
Ntfy,
Pushbullet,
@ -41,6 +52,8 @@ export default {
userAuthenticated: false,
connected: false,
pwaInstallEvent: null,
initialized: false,
initError: null,
}
},
@ -84,8 +97,18 @@ export default {
}
},
created() {
this.initConfig()
async created() {
try {
await this.initConfig()
} catch (e) {
const code = e?.response?.data?.code
if (![401, 403, 412].includes(code)) {
this.initError = e
console.error('Initialization error', e)
}
} finally {
this.initialized = true
}
},
beforeMount() {
@ -125,6 +148,11 @@ html, body {
overflow: auto;
}
#app-container {
width: 100%;
height: 100%;
}
#app {
font-family: BlinkMacSystemFont,-apple-system,Avenir,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,Helvetica,Verdana,Arial,sans-serif;
-webkit-font-smoothing: antialiased;

View file

@ -8,6 +8,7 @@ $default-bg-6: #e4eae8 !default;
$default-bg-7: #e4e4e4 !default;
$ok-fg: #17ad17 !default;
$error-fg: #ad1717 !default;
$error-bg: #ffaaa2 !default;
$tile-bg: linear-gradient(90deg, rgba(9,174,128,1) 0%, rgba(71,226,179,1) 120%);
$tile-fg: white;
$tile-hover-bg: linear-gradient(90deg, rgba(41,216,159,1) 0%, rgba(9,188,138,1) 70%);

View file

@ -1,6 +1,8 @@
<template>
<div class="login-container">
<form class="login" method="POST">
<Loading v-if="!initialized" />
<div class="login-container" v-else>
<form class="login" method="POST" @submit="submitForm" v-if="!isAuthenticated">
<div class="header">
<span class="logo">
<img src="/logo.svg" alt="logo" />
@ -10,24 +12,30 @@
<div class="row">
<label>
<input type="text" name="username" placeholder="Username">
<input type="text" name="username" :disabled="authenticating" placeholder="Username" ref="username">
</label>
</div>
<div class="row">
<label>
<input type="password" name="password" placeholder="Password">
<input type="password" name="password" :disabled="authenticating" placeholder="Password">
</label>
</div>
<div class="row" v-if="_register">
<div class="row" v-if="register">
<label>
<input type="password" name="confirm_password" placeholder="Confirm password">
<input type="password" name="confirm_password" :disabled="authenticating" placeholder="Confirm password">
</label>
</div>
<div class="row buttons">
<input type="submit" class="btn btn-primary" :value="_register ? 'Register' : 'Login'">
<button type="submit"
class="btn btn-primary"
:class="{loading: authenticating}"
:disabled="authenticating">
<Loading v-if="authenticating" />
{{ register ? 'Register' : 'Login' }}
</button>
</div>
<div class="row pull-right">
@ -36,16 +44,26 @@
Keep me logged in on this device &nbsp;
</label>
</div>
<div class="auth-error" v-if="authError">
{{ authError }}
</div>
</form>
</div>
</template>
<script>
import Loading from "@/components/Loading";
import Utils from "@/Utils";
import axios from 'axios'
export default {
name: "Login",
mixins: [Utils],
components: {
Loading,
},
props: {
// Set to true for a registration form, false for a login form
register: {
@ -56,10 +74,86 @@ export default {
},
computed: {
_register() {
return this.parseBoolean(this.register)
redirect() {
return this.$route.query.redirect?.length ? this.$route.query.redirect : '/'
},
},
data() {
return {
authError: null,
authenticating: false,
isAuthenticated: false,
initialized: false,
}
},
methods: {
async submitForm(e) {
e.preventDefault();
const form = e.target
const data = new FormData(form)
const url = `/auth?type=${this.register ? 'register' : 'login'}`
if (this.register && data.get('password') !== data.get('confirm_password')) {
this.authError = "Passwords don't match"
return
}
this.authError = null
try {
const authStatus = await axios.post(url, data)
const sessionToken = authStatus?.data?.session_token
if (sessionToken) {
const expiresAt = authStatus.expires_at ? Date.parse(authStatus.expires_at) : null
this.isAuthenticated = true
this.setCookie('session_token', sessionToken, {
expires: expiresAt,
})
window.location.href = authStatus.redirect || this.redirect
} else {
this.authError = "Invalid credentials"
}
} catch (e) {
this.authError = e.response.data.message || e.response.data.error
if (e.response?.status === 401) {
this.authError = this.authError || "Invalid credentials"
} else {
this.authError = this.authError || "An error occurred while processing the request"
if (e.response)
console.error(e.response.status, e.response.data)
else
console.error(e)
}
}
},
async checkAuth() {
try {
const authStatus = await axios.get('/auth')
if (authStatus.data.session_token) {
this.isAuthenticated = true
window.location.href = authStatus.redirect || this.redirect
}
} catch (e) {
this.isAuthenticated = false
} finally {
this.initialized = true
}
},
},
async created() {
await this.checkAuth()
},
async mounted() {
this.$nextTick(() => {
this.$refs.username?.focus()
})
},
}
</script>
@ -116,7 +210,7 @@ form {
width: 100%;
}
input[type=submit],
[type=submit],
input[type=password] {
border-radius: 1em;
}
@ -133,10 +227,33 @@ form {
.buttons {
text-align: center;
input[type=submit] {
[type=submit] {
position: relative;
width: 6em;
height: 2.5em;
padding: .5em .75em;
display: inline-flex;
align-items: center;
justify-content: center;
&.loading {
background: none;
border: none;
cursor: not-allowed;
}
}
}
.auth-error {
background: $error-bg;
display: flex;
margin: 1em 0 -2em 0;
padding: .5em;
align-items: center;
justify-content: center;
border: $notification-error-border;
border-radius: 1em;
}
}
a {

View file

@ -34,16 +34,17 @@ module.exports = {
devServer: {
proxy: {
'^/auth': httpProxy,
'^/camera/': httpProxy,
'^/execute': httpProxy,
'^/file': httpProxy,
'^/logo.svg': httpProxy,
'^/logout': httpProxy,
'^/media/': httpProxy,
'^/sound/': httpProxy,
'^/ws/events': wsProxy,
'^/ws/requests': wsProxy,
'^/ws/shell': wsProxy,
'^/execute': httpProxy,
'^/file': httpProxy,
'^/auth': httpProxy,
'^/camera/': httpProxy,
'^/sound/': httpProxy,
'^/media/': httpProxy,
'^/logo.svg': httpProxy,
}
}
};

View file

@ -1,4 +1,4 @@
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional, Tuple, Union
from platypush.plugins import Plugin, action
from platypush.user import UserManager
@ -66,15 +66,46 @@ class UserPlugin(Plugin):
}
@action
def authenticate_user(self, username, password):
def authenticate_user(
self,
username: str,
password: str,
code: Optional[str] = None,
return_details: bool = False,
) -> Union[bool, Tuple[bool, str]]:
"""
Authenticate a user.
:return: True if the provided username and password are correct, False
otherwise.
"""
:param username: Username.
:param password: Password.
:param code: Optional 2FA code, if 2FA is enabled for the user.
:param return_error_details: If True then return the error details in
case of authentication failure.
:return: If ``return_details`` is False (default), the action returns
True if the provided credentials are valid, False otherwise.
If ``return_details`` is True then the action returns a tuple
(authenticated, error_details) where ``authenticated`` is True if
the provided credentials are valid, False otherwise, and
``error_details`` is a string containing the error details in case
of authentication failure. Supported error details are:
return bool(self.user_manager.authenticate_user(username, password))
- ``invalid_credentials``: Invalid username or password.
- ``invalid_otp_code``: Invalid 2FA code.
- ``missing_otp_code``: Username/password are correct, but a 2FA
code is required for the user.
"""
response = self.user_manager.authenticate_user(
username, password, code=code, return_error=return_details
)
if return_details:
assert (
isinstance(response, tuple) and len(response) == 2
), 'Invalid response from authenticate_user'
return response[0], response[1].value
return response
@action
def update_password(self, username, old_password, new_password):
@ -111,7 +142,7 @@ class UserPlugin(Plugin):
return None, "No such user: {}".format(username)
@action
def create_session(self, username, password, expires_at=None):
def create_session(self, username, password, code=None, expires_at=None):
"""
Create a user session.
@ -130,7 +161,7 @@ class UserPlugin(Plugin):
"""
session = self.user_manager.create_user_session(
username=username, password=password, expires_at=expires_at
username=username, password=password, code=code, expires_at=expires_at
)
if not session:
@ -140,9 +171,9 @@ class UserPlugin(Plugin):
'session_token': session.session_token,
'user_id': session.user_id,
'created_at': session.created_at.isoformat(),
'expires_at': session.expires_at.isoformat()
if session.expires_at
else None,
'expires_at': (
session.expires_at.isoformat() if session.expires_at else None # type: ignore
),
}
@action

View file

@ -1,11 +1,12 @@
import base64
import datetime
import enum
import hashlib
import json
import os
import random
import time
from typing import Optional, Dict
from typing import List, Optional, Dict, Tuple, Union
import rsa
@ -13,12 +14,32 @@ from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient
from platypush.common.db import Base
from platypush.config import Config
from platypush.context import get_plugin
from platypush.exceptions.user import (
InvalidJWTTokenException,
InvalidCredentialsException,
)
from platypush.utils import get_or_generate_jwt_rsa_key_pair, utcnow
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
class AuthenticationStatus(enum.Enum):
"""
Enum for authentication errors.
"""
OK = ''
INVALID_AUTH_TYPE = 'invalid_auth_type'
INVALID_CREDENTIALS = 'invalid_credentials'
INVALID_METHOD = 'invalid_method'
INVALID_JWT_TOKEN = 'invalid_jwt_token'
INVALID_OTP_CODE = 'invalid_otp_code'
MISSING_OTP_CODE = 'missing_otp_code'
MISSING_PASSWORD = 'missing_password'
MISSING_USERNAME = 'missing_username'
PASSWORD_MISMATCH = 'password_mismatch'
REGISTRATION_DISABLED = 'registration_disabled'
REGISTRATION_REQUIRED = 'registration_required'
class UserManager:
@ -26,6 +47,14 @@ class UserManager:
Main class for managing platform users
"""
_otp_workdir = os.path.join(Config.get_workdir(), 'otp')
_otp_keyfile = os.path.join(_otp_workdir, 'key')
_otp_keyfile_pub = f'{_otp_keyfile}.pub'
_jwt_workdir = os.path.join(Config.get_workdir(), 'jwt')
_jwt_keyfile = os.path.join(_jwt_workdir, 'id_rsa')
_jwt_keyfile_pub = f'{_jwt_keyfile}.pub'
def __init__(self):
db_plugin = get_plugin('db')
assert db_plugin, 'Database plugin not configured'
@ -41,6 +70,44 @@ class UserManager:
def _get_session(self, *args, **kwargs):
return self.db.get_session(self.db.get_engine(), *args, **kwargs)
@classmethod
def _get_jwt_rsa_key_pair(cls):
"""
Get or generate the JWT RSA key pair.
"""
return get_or_generate_stored_rsa_key_pair(cls._jwt_keyfile, size=2048)
@classmethod
def _get_or_generate_otp_rsa_key_pair(cls):
"""
Get or generate the OTP RSA key pair.
"""
return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=4096)
@staticmethod
def _encrypt(data: Union[str, bytes, dict, list, tuple], key: rsa.PublicKey) -> str:
"""
Encrypt the data using the given RSA public key.
"""
if isinstance(data, tuple):
data = list(data)
if isinstance(data, (dict, list)):
data = json.dumps(data, sort_keys=True, indent=None)
if isinstance(data, str):
data = data.encode('ascii')
return base64.b64encode(rsa.encrypt(data, key)).decode()
@staticmethod
def _decrypt(data: Union[str, bytes], key: rsa.PrivateKey) -> str:
"""
Decrypt the data using the given RSA private key.
"""
if isinstance(data, str):
data = data.encode('ascii')
return rsa.decrypt(base64.b64decode(data), key).decode()
def get_user(self, username):
with self._get_session() as session:
user = self._get_user(session, username)
@ -88,9 +155,9 @@ class UserManager:
return self._mask_password(user)
def update_password(self, username, old_password, new_password):
def update_password(self, username, old_password, new_password, code=None):
with self._get_session(locked=True) as session:
if not self._authenticate_user(session, username, old_password):
if not self._authenticate_user(session, username, old_password, code=code):
return False
user = self._get_user(session, username)
@ -103,12 +170,22 @@ class UserManager:
session.commit()
return True
def authenticate_user(self, username, password):
def authenticate_user(self, username, password, code=None, return_error=False):
with self._get_session() as session:
return self._authenticate_user(session, username, password)
return self._authenticate_user(
session, username, password, code=code, return_error=return_error
)
def authenticate_user_session(self, session_token):
def authenticate_user_session(self, session_token, with_error=False):
with self._get_session() as session:
users_count = session.query(User).count()
if not users_count:
return (
(None, None, AuthenticationStatus.REGISTRATION_REQUIRED)
if with_error
else (None, None)
)
user_session = (
session.query(UserSession)
.filter_by(session_token=session_token)
@ -122,10 +199,21 @@ class UserManager:
)
if not user_session or (expires_at and expires_at < utcnow()):
return None, None
return (
(None, None, AuthenticationStatus.INVALID_CREDENTIALS)
if with_error
else (None, None)
)
user = session.query(User).filter_by(user_id=user_session.user_id).first()
return self._mask_password(user), user_session
return (
(self._mask_password(user), user_session, AuthenticationStatus.OK)
if with_error
else (
self._mask_password(user),
user_session,
)
)
def delete_user(self, username):
with self._get_session(locked=True) as session:
@ -158,11 +246,25 @@ class UserManager:
session.commit()
return True
def create_user_session(self, username, password, expires_at=None):
def create_user_session(
self,
username,
password,
code=None,
expires_at=None,
error_on_invalid_credentials=False,
):
with self._get_session(locked=True) as session:
user = self._authenticate_user(session, username, password)
user, status = self._authenticate_user( # type: ignore
session,
username,
password,
code=code,
return_error=error_on_invalid_credentials,
)
if not user:
return None
return None if not error_on_invalid_credentials else (None, status)
if expires_at:
if isinstance(expires_at, (int, float)):
@ -180,7 +282,53 @@ class UserManager:
session.add(user_session)
session.commit()
return user_session
return user_session, (
AuthenticationStatus.OK if not error_on_invalid_credentials else status
)
def create_otp_secret(
self, username: str, expires_at: Optional[datetime.datetime] = None
):
pubkey, _ = self._get_or_generate_otp_rsa_key_pair()
# Generate a new OTP secret and encrypt it with the OTP RSA key pair
otp_secret = "".join(
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(32)
)
encrypted_secret = self._encrypt(otp_secret, pubkey)
with self._get_session(locked=True) as session:
user = self._get_user(session, username)
assert user, f'No such user: {username}'
# Create a new OTP secret
user_otp = UserOtp(
user_id=user.user_id,
otp_secret=encrypted_secret,
created_at=utcnow(),
expires_at=expires_at,
)
# Remove any existing OTP secret and replace it with the new one
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
session.add(user_otp)
session.commit()
return user_otp
def get_otp_secret(self, username: str) -> Optional[str]:
with self._get_session() as session:
user = self._get_user(session, username)
if not user:
return None
user_otp = session.query(UserOtp).filter_by(user_id=user.user_id).first()
if not user_otp:
return None
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
return self._decrypt(user_otp.otp_secret, priv_key)
@staticmethod
def _get_user(session, username):
@ -268,20 +416,17 @@ class UserManager:
if not user:
raise InvalidCredentialsException()
pub_key, _ = get_or_generate_jwt_rsa_key_pair()
payload = json.dumps(
pub_key, _ = self._get_jwt_rsa_key_pair()
return self._encrypt(
{
'username': username,
'password': password,
'created_at': datetime.datetime.now().timestamp(),
'expires_at': expires_at.timestamp() if expires_at else None,
},
sort_keys=True,
indent=None,
pub_key,
)
return base64.b64encode(rsa.encrypt(payload.encode('ascii'), pub_key)).decode()
def validate_jwt_token(self, token: str) -> Dict[str, str]:
"""
Validate a JWT token.
@ -299,14 +444,10 @@ class UserManager:
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
"""
_, priv_key = get_or_generate_jwt_rsa_key_pair()
_, priv_key = self._get_jwt_rsa_key_pair()
try:
payload = json.loads(
rsa.decrypt(base64.b64decode(token.encode('ascii')), priv_key).decode(
'ascii'
)
)
payload = json.loads(self._decrypt(token, priv_key))
except (TypeError, ValueError) as e:
raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e
@ -323,23 +464,160 @@ class UserManager:
return payload
def _authenticate_user(self, session, username, password):
def _authenticate_user(
self,
session,
username: str,
password: str,
code: Optional[str] = None,
return_error: bool = False,
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
"""
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
:return: :class:`platypush.user.User` instance if the user exists and
the password is valid, ``None`` otherwise.
"""
user = self._get_user(session, username)
if not user:
return None
# The user does not exist
if not user:
return (
None
if not return_error
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
)
# The password is not correct
if not self._check_password(
password,
user.password,
bytes.fromhex(user.password_salt) if user.password_salt else None,
user.hmac_iterations,
):
return None
return (
None
if not return_error
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
)
return user
otp_secret = self.get_otp_secret(username)
# The user doesn't have 2FA enabled and the password is correct:
# authentication successful
if not otp_secret:
return user if not return_error else (user, AuthenticationStatus.OK)
# The user has 2FA enabled but the code is missing
if not code:
return (
None
if not return_error
else (None, AuthenticationStatus.MISSING_OTP_CODE)
)
if self.validate_otp_code(username, code):
return user if not return_error else (user, AuthenticationStatus.OK)
if not self.validate_backup_code(username, code):
return (
None
if not return_error
else (None, AuthenticationStatus.INVALID_OTP_CODE)
)
return user if not return_error else (user, AuthenticationStatus.OK)
def refresh_user_backup_codes(self, username: str):
"""
Refresh the backup codes for a user with 2FA enabled.
"""
with self._get_session(locked=True) as session:
user = self._get_user(session, username)
if not user:
return False
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
for _ in range(10):
backup_code = "".join(
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16)
)
user_backup_code = UserBackupCode(
user_id=user.user_id,
code=self._encrypt(backup_code, pub_key),
created_at=utcnow(),
expires_at=utcnow() + datetime.timedelta(days=30),
)
session.add(user_backup_code)
session.commit()
return True
def get_user_backup_codes(self, username: str) -> List['UserBackupCode']:
with self._get_session() as session:
user = self._get_user(session, username)
if not user:
return []
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
codes = session.query(UserBackupCode).filter_by(user_id=user.user_id).all()
for code in codes:
code.code = self._decrypt(code.code, priv_key)
return codes
def validate_backup_code(self, username: str, code: str) -> bool:
with self._get_session() as session:
user = self._get_user(session, username)
if not user:
return False
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
user_backup_code = (
session.query(UserBackupCode)
.filter_by(user_id=user.user_id, code=self._encrypt(code, pub_key))
.first()
)
if not user_backup_code:
return False
session.delete(user_backup_code)
session.commit()
return True
def validate_otp_code(self, username: str, code: str) -> bool:
otp = get_plugin('otp')
assert otp
with self._get_session() as session:
user = self._get_user(session, username)
if not user:
return False
otp_secret = self.get_otp_secret(username)
if not otp_secret:
return False
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
otp_secret = self._decrypt(otp_secret, priv_key)
return otp.verify_time_otp(otp=code, secret=otp_secret)
def disable_mfa(self, username: str):
with self._get_session(locked=True) as session:
user = self._get_user(session, username)
if not user:
return False
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
session.commit()
return True
class User(Base):
@ -370,4 +648,31 @@ class UserSession(Base):
expires_at = Column(DateTime)
class UserOtp(Base):
"""
Models the UserOtp table, which contains the OTP secrets for each user.
"""
__tablename__ = 'user_otp'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
otp_secret = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
class UserBackupCode(Base):
"""
Models the UserBackupCode table, which contains the backup codes for each
user with 2FA enabled.
"""
__tablename__ = 'user_backup_code'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
code = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
# vim:sw=4:ts=4:et:

View file

@ -532,7 +532,12 @@ def generate_rsa_key_pair(
private_key_str = priv_key.save_pkcs1('PEM').decode()
if key_file:
pathlib.Path(os.path.dirname(os.path.expanduser(key_file))).mkdir(
parents=True, exist_ok=True
)
logger.info('Saving private key to %s', key_file)
with open(os.path.expanduser(key_file), 'w') as f1, open(
os.path.expanduser(key_file) + '.pub', 'w'
) as f2:
@ -543,14 +548,20 @@ def generate_rsa_key_pair(
return pub_key, priv_key
def get_or_generate_jwt_rsa_key_pair():
def get_or_generate_stored_rsa_key_pair(
keyfile: str, size: int = 2048
) -> Tuple[PublicKey, PrivateKey]:
"""
Get or generate a JWT RSA key pair.
"""
from platypush.config import Config
Get or generate an RSA key pair and store it in the given key file.
key_dir = os.path.join(Config.get_workdir(), 'jwt')
priv_key_file = os.path.join(key_dir, 'id_rsa')
The private key will be stored in the given file, while the public key will
be stored in ``<keyfile>.pub``.
:param keyfile: Path to the key file.
:param size: Key size in bits (default: 2048).
"""
keydir = os.path.dirname(os.path.expanduser(keyfile))
priv_key_file = os.path.join(keydir, os.path.basename(keyfile))
pub_key_file = priv_key_file + '.pub'
if os.path.isfile(priv_key_file) and os.path.isfile(pub_key_file):
@ -560,8 +571,15 @@ def get_or_generate_jwt_rsa_key_pair():
PrivateKey.load_pkcs1(f2.read().encode()),
)
pathlib.Path(key_dir).mkdir(parents=True, exist_ok=True, mode=0o755)
return generate_rsa_key_pair(priv_key_file, size=2048)
pub_key, priv_key = generate_rsa_key_pair(priv_key_file, size=size)
pathlib.Path(keydir).mkdir(parents=True, exist_ok=True, mode=0o755)
with open(pub_key_file, 'w') as f1, open(priv_key_file, 'w') as f2:
f1.write(pub_key.save_pkcs1('PEM').decode())
f2.write(priv_key.save_pkcs1('PEM').decode())
os.chmod(priv_key_file, 0o600)
return pub_key, priv_key
def get_enabled_plugins() -> dict: