[core] Refactoring user/authentication layer.

- Separated the user model/db classes from the `UserManager`.
- More consistent naming for the flag on the `authenticate_*` functions
  that enables returning a tuple with the authentication status - all
  those flags are now named `with_status`.
This commit is contained in:
Fabio Manganiello 2024-07-23 22:44:40 +02:00
parent ee27b2c4c6
commit 2033f9760a
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
6 changed files with 141 additions and 105 deletions

View file

@ -63,7 +63,7 @@ def _session_auth():
redirect_page = request.args.get('redirect') or '/' redirect_page = request.args.get('redirect') or '/'
if session_token: if session_token:
user, session = user_manager.authenticate_user_session(session_token) # type: ignore user, session = user_manager.authenticate_user_session(session_token)[:2]
if user and session: if user and session:
return _dump_session(session, redirect_page) return _dump_session(session, redirect_page)
@ -78,7 +78,7 @@ def _session_auth():
password=password, password=password,
code=code, code=code,
expires_at=expires, expires_at=expires,
error_on_invalid_credentials=True, with_status=True,
) )
if session: if session:
@ -97,7 +97,7 @@ def _register_route():
redirect_page = request.args.get('redirect') or '/' redirect_page = request.args.get('redirect') or '/'
if session_token: if session_token:
user, session = user_manager.authenticate_user_session(session_token) # type: ignore user, session = user_manager.authenticate_user_session(session_token)[:2]
if user and session: if user and session:
return _dump_session(session, redirect_page) return _dump_session(session, redirect_page)
@ -124,7 +124,7 @@ def _register_route():
username=username, username=username,
password=password, password=password,
expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None), expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None),
error_on_invalid_credentials=True, with_status=True,
) )
if session: if session:
@ -144,7 +144,7 @@ def _auth_get():
session_token = request.cookies.get('session_token') session_token = request.cookies.get('session_token')
redirect_page = request.args.get('redirect') or '/' redirect_page = request.args.get('redirect') or '/'
user, session, status = user_manager.authenticate_user_session( # type: ignore user, session, status = user_manager.authenticate_user_session( # type: ignore
session_token, with_error=True session_token, with_status=True
) )
if user and session: if user and session:

View file

@ -12,7 +12,7 @@ __routes__ = [
@logout.route('/logout', methods=['GET', 'POST']) @logout.route('/logout', methods=['GET', 'POST'])
def logout(): def logout_route():
"""Logout page""" """Logout page"""
user_manager = UserManager() user_manager = UserManager()
redirect_page = request.args.get( redirect_page = request.args.get(
@ -23,7 +23,7 @@ def logout():
if not session_token: if not session_token:
abort(417, 'Not logged in') abort(417, 'Not logged in')
user, _ = user_manager.authenticate_user_session(session_token) user, _ = user_manager.authenticate_user_session(session_token)[:2]
if not user: if not user:
abort(403, 'Invalid session token') abort(403, 'Invalid session token')

View file

@ -106,7 +106,7 @@ def authenticate_session(req):
user_session_token = get_cookie(req, 'session_token') user_session_token = get_cookie(req, 'session_token')
if user_session_token: if user_session_token:
user, _ = user_manager.authenticate_user_session(user_session_token) user, _ = user_manager.authenticate_user_session(user_session_token)[:2]
return user is not None return user is not None

View file

@ -21,7 +21,7 @@ class UserPlugin(Plugin):
executing_user=None, executing_user=None,
executing_user_password=None, executing_user_password=None,
session_token=None, session_token=None,
**kwargs **kwargs,
): ):
""" """
Create a user. This action needs to be executed by an already existing Create a user. This action needs to be executed by an already existing
@ -50,7 +50,7 @@ class UserPlugin(Plugin):
if not self.user_manager.authenticate_user( if not self.user_manager.authenticate_user(
executing_user, executing_user_password executing_user, executing_user_password
): ):
user, _ = self.user_manager.authenticate_user_session(session_token) user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
if not user: if not user:
return None, "Invalid credentials and/or session_token" return None, "Invalid credentials and/or session_token"
@ -132,14 +132,14 @@ class UserPlugin(Plugin):
if not self.user_manager.authenticate_user( if not self.user_manager.authenticate_user(
executing_user, executing_user_password executing_user, executing_user_password
): ):
user, _ = self.user_manager.authenticate_user_session(session_token) user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
if not user: if not user:
return None, "Invalid credentials and/or session_token" return None, "Invalid credentials and/or session_token"
try: try:
return self.user_manager.delete_user(username) return self.user_manager.delete_user(username)
except NameError: except NameError:
return None, "No such user: {}".format(username) return None, f"No such user: {username}"
@action @action
def create_session(self, username, password, code=None, expires_at=None): def create_session(self, username, password, code=None, expires_at=None):
@ -164,6 +164,9 @@ class UserPlugin(Plugin):
username=username, password=password, code=code, expires_at=expires_at username=username, password=password, code=code, expires_at=expires_at
) )
if isinstance(session, tuple):
session = session[0]
if not session: if not session:
return None, "Invalid credentials" return None, "Invalid credentials"
@ -193,9 +196,7 @@ class UserPlugin(Plugin):
""" """
user, _ = self.user_manager.authenticate_user_session( user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
session_token=session_token
)
if not user: if not user:
return None, 'Invalid session token' return None, 'Invalid session token'

View file

@ -1,6 +1,5 @@
import base64 import base64
import datetime import datetime
import enum
import hashlib import hashlib
import json import json
import os import os
@ -10,7 +9,6 @@ from typing import List, Optional, Dict, Tuple, Union
import rsa import rsa
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient from sqlalchemy.orm import make_transient
from platypush.common.db import Base from platypush.common.db import Base
@ -22,24 +20,7 @@ from platypush.exceptions.user import (
) )
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
from ._model import User, UserSession, UserOtp, UserBackupCode, AuthenticationStatus
class AuthenticationStatus(enum.Enum):
"""
Enum for authentication errors.
"""
OK = ''
INVALID_AUTH_TYPE = 'invalid_auth_type'
INVALID_CREDENTIALS = 'invalid_credentials'
INVALID_METHOD = 'invalid_method'
INVALID_JWT_TOKEN = 'invalid_jwt_token'
INVALID_OTP_CODE = 'invalid_otp_code'
MISSING_OTP_CODE = 'missing_otp_code'
MISSING_PASSWORD = 'missing_password'
MISSING_USERNAME = 'missing_username'
PASSWORD_MISMATCH = 'password_mismatch'
REGISTRATION_DISABLED = 'registration_disabled'
REGISTRATION_REQUIRED = 'registration_required'
class UserManager: class UserManager:
@ -173,16 +154,16 @@ class UserManager:
def authenticate_user(self, username, password, code=None, return_error=False): def authenticate_user(self, username, password, code=None, return_error=False):
with self._get_session() as session: with self._get_session() as session:
return self._authenticate_user( return self._authenticate_user(
session, username, password, code=code, return_error=return_error session, username, password, code=code, with_status=return_error
) )
def authenticate_user_session(self, session_token, with_error=False): def authenticate_user_session(self, session_token, with_status=False):
with self._get_session() as session: with self._get_session() as session:
users_count = session.query(User).count() users_count = session.query(User).count()
if not users_count: if not users_count:
return ( return (
(None, None, AuthenticationStatus.REGISTRATION_REQUIRED) (None, None, AuthenticationStatus.REGISTRATION_REQUIRED)
if with_error if with_status
else (None, None) else (None, None)
) )
@ -201,14 +182,14 @@ class UserManager:
if not user_session or (expires_at and expires_at < utcnow()): if not user_session or (expires_at and expires_at < utcnow()):
return ( return (
(None, None, AuthenticationStatus.INVALID_CREDENTIALS) (None, None, AuthenticationStatus.INVALID_CREDENTIALS)
if with_error if with_status
else (None, None) else (None, None)
) )
user = session.query(User).filter_by(user_id=user_session.user_id).first() user = session.query(User).filter_by(user_id=user_session.user_id).first()
return ( return (
(self._mask_password(user), user_session, AuthenticationStatus.OK) (self._mask_password(user), user_session, AuthenticationStatus.OK)
if with_error if with_status
else ( else (
self._mask_password(user), self._mask_password(user),
user_session, user_session,
@ -252,7 +233,7 @@ class UserManager:
password, password,
code=None, code=None,
expires_at=None, expires_at=None,
error_on_invalid_credentials=False, with_status=False,
): ):
with self._get_session(locked=True) as session: with self._get_session(locked=True) as session:
user, status = self._authenticate_user( # type: ignore user, status = self._authenticate_user( # type: ignore
@ -260,11 +241,11 @@ class UserManager:
username, username,
password, password,
code=code, code=code,
return_error=error_on_invalid_credentials, with_status=with_status,
) )
if not user: if not user:
return None if not error_on_invalid_credentials else (None, status) return None if not with_status else (None, status)
if expires_at: if expires_at:
if isinstance(expires_at, (int, float)): if isinstance(expires_at, (int, float)):
@ -283,7 +264,7 @@ class UserManager:
session.add(user_session) session.add(user_session)
session.commit() session.commit()
return user_session, ( return user_session, (
AuthenticationStatus.OK if not error_on_invalid_credentials else status AuthenticationStatus.OK if not with_status else status
) )
def create_otp_secret( def create_otp_secret(
@ -470,7 +451,7 @@ class UserManager:
username: str, username: str,
password: str, password: str,
code: Optional[str] = None, code: Optional[str] = None,
return_error: bool = False, with_status: bool = False,
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]: ) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
""" """
:return: :class:`platypush.user.User` instance if the user exists and :return: :class:`platypush.user.User` instance if the user exists and
@ -482,7 +463,7 @@ class UserManager:
if not user: if not user:
return ( return (
None None
if not return_error if not with_status
else (None, AuthenticationStatus.INVALID_CREDENTIALS) else (None, AuthenticationStatus.INVALID_CREDENTIALS)
) )
@ -495,7 +476,7 @@ class UserManager:
): ):
return ( return (
None None
if not return_error if not with_status
else (None, AuthenticationStatus.INVALID_CREDENTIALS) else (None, AuthenticationStatus.INVALID_CREDENTIALS)
) )
@ -504,27 +485,27 @@ class UserManager:
# The user doesn't have 2FA enabled and the password is correct: # The user doesn't have 2FA enabled and the password is correct:
# authentication successful # authentication successful
if not otp_secret: if not otp_secret:
return user if not return_error else (user, AuthenticationStatus.OK) return user if not with_status else (user, AuthenticationStatus.OK)
# The user has 2FA enabled but the code is missing # The user has 2FA enabled but the code is missing
if not code: if not code:
return ( return (
None None
if not return_error if not with_status
else (None, AuthenticationStatus.MISSING_OTP_CODE) else (None, AuthenticationStatus.MISSING_OTP_CODE)
) )
if self.validate_otp_code(username, code): if self.validate_otp_code(username, code):
return user if not return_error else (user, AuthenticationStatus.OK) return user if not with_status else (user, AuthenticationStatus.OK)
if not self.validate_backup_code(username, code): if not self.validate_backup_code(username, code):
return ( return (
None None
if not return_error if not with_status
else (None, AuthenticationStatus.INVALID_OTP_CODE) else (None, AuthenticationStatus.INVALID_OTP_CODE)
) )
return user if not return_error else (user, AuthenticationStatus.OK) return user if not with_status else (user, AuthenticationStatus.OK)
def refresh_user_backup_codes(self, username: str): def refresh_user_backup_codes(self, username: str):
""" """
@ -619,60 +600,15 @@ class UserManager:
session.commit() session.commit()
return True return True
def enable_mfa(self, username: str):
with self._get_session() as session:
user = self._get_user(session, username)
if not user:
return False
class User(Base): self.create_otp_secret(username)
"""Models the User table""" self.refresh_user_backup_codes(username)
return True
__tablename__ = 'user'
__table_args__ = {'sqlite_autoincrement': True}
user_id = Column(Integer, primary_key=True)
username = Column(String, unique=True, nullable=False)
password = Column(String)
password_salt = Column(String)
hmac_iterations = Column(Integer)
created_at = Column(DateTime)
class UserSession(Base):
"""Models the UserSession table"""
__tablename__ = 'user_session'
__table_args__ = {'sqlite_autoincrement': True}
session_id = Column(Integer, primary_key=True)
session_token = Column(String, unique=True, nullable=False)
csrf_token = Column(String, unique=True)
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
created_at = Column(DateTime)
expires_at = Column(DateTime)
class UserOtp(Base):
"""
Models the UserOtp table, which contains the OTP secrets for each user.
"""
__tablename__ = 'user_otp'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
otp_secret = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
class UserBackupCode(Base):
"""
Models the UserBackupCode table, which contains the backup codes for each
user with 2FA enabled.
"""
__tablename__ = 'user_backup_code'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
code = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
# vim:sw=4:ts=4:et: # vim:sw=4:ts=4:et:

99
platypush/user/_model.py Normal file
View file

@ -0,0 +1,99 @@
from dataclasses import dataclass, field
import datetime
import enum
from typing import List, Optional
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from platypush.common.db import Base
class AuthenticationStatus(enum.Enum):
"""
Enum for authentication errors.
"""
OK = ''
INVALID_AUTH_TYPE = 'invalid_auth_type'
INVALID_CREDENTIALS = 'invalid_credentials'
INVALID_METHOD = 'invalid_method'
INVALID_JWT_TOKEN = 'invalid_jwt_token'
INVALID_OTP_CODE = 'invalid_otp_code'
MISSING_OTP_CODE = 'missing_otp_code'
MISSING_PASSWORD = 'missing_password'
MISSING_USERNAME = 'missing_username'
PASSWORD_MISMATCH = 'password_mismatch'
REGISTRATION_DISABLED = 'registration_disabled'
REGISTRATION_REQUIRED = 'registration_required'
class User(Base):
"""Models the User table"""
__tablename__ = 'user'
__table_args__ = {'sqlite_autoincrement': True}
user_id = Column(Integer, primary_key=True)
username = Column(String, unique=True, nullable=False)
password = Column(String)
password_salt = Column(String)
hmac_iterations = Column(Integer)
created_at = Column(DateTime)
class UserSession(Base):
"""Models the UserSession table"""
__tablename__ = 'user_session'
__table_args__ = {'sqlite_autoincrement': True}
session_id = Column(Integer, primary_key=True)
session_token = Column(String, unique=True, nullable=False)
csrf_token = Column(String, unique=True)
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
created_at = Column(DateTime)
expires_at = Column(DateTime)
class UserOtp(Base):
"""
Models the UserOtp table, which contains the OTP secrets for each user.
"""
__tablename__ = 'user_otp'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
otp_secret = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
class UserBackupCode(Base):
"""
Models the UserBackupCode table, which contains the backup codes for each
user with 2FA enabled.
"""
__tablename__ = 'user_backup_code'
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
code = Column(String, nullable=False, unique=True)
created_at = Column(DateTime)
expires_at = Column(DateTime)
@dataclass
class UserResponse:
"""
Dataclass containing full information about a user (minus the password).
"""
user_id: int
username: str
otp_secret: Optional[str] = None
session_token: Optional[str] = None
created_at: Optional[datetime.datetime] = None
backup_codes: List[str] = field(default_factory=list)
# vim:sw=4:ts=4:et: