forked from platypush/platypush
[core] Refactoring user/authentication layer.
- Separated the user model/db classes from the `UserManager`. - More consistent naming for the flag on the `authenticate_*` functions that enables returning a tuple with the authentication status - all those flags are now named `with_status`.
This commit is contained in:
parent
ee27b2c4c6
commit
2033f9760a
6 changed files with 141 additions and 105 deletions
|
@ -63,7 +63,7 @@ def _session_auth():
|
|||
redirect_page = request.args.get('redirect') or '/'
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||
user, session = user_manager.authenticate_user_session(session_token)[:2]
|
||||
if user and session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
|
@ -78,7 +78,7 @@ def _session_auth():
|
|||
password=password,
|
||||
code=code,
|
||||
expires_at=expires,
|
||||
error_on_invalid_credentials=True,
|
||||
with_status=True,
|
||||
)
|
||||
|
||||
if session:
|
||||
|
@ -97,7 +97,7 @@ def _register_route():
|
|||
redirect_page = request.args.get('redirect') or '/'
|
||||
|
||||
if session_token:
|
||||
user, session = user_manager.authenticate_user_session(session_token) # type: ignore
|
||||
user, session = user_manager.authenticate_user_session(session_token)[:2]
|
||||
if user and session:
|
||||
return _dump_session(session, redirect_page)
|
||||
|
||||
|
@ -124,7 +124,7 @@ def _register_route():
|
|||
username=username,
|
||||
password=password,
|
||||
expires_at=(utcnow() + datetime.timedelta(days=365) if remember else None),
|
||||
error_on_invalid_credentials=True,
|
||||
with_status=True,
|
||||
)
|
||||
|
||||
if session:
|
||||
|
@ -144,7 +144,7 @@ def _auth_get():
|
|||
session_token = request.cookies.get('session_token')
|
||||
redirect_page = request.args.get('redirect') or '/'
|
||||
user, session, status = user_manager.authenticate_user_session( # type: ignore
|
||||
session_token, with_error=True
|
||||
session_token, with_status=True
|
||||
)
|
||||
|
||||
if user and session:
|
||||
|
|
|
@ -12,7 +12,7 @@ __routes__ = [
|
|||
|
||||
|
||||
@logout.route('/logout', methods=['GET', 'POST'])
|
||||
def logout():
|
||||
def logout_route():
|
||||
"""Logout page"""
|
||||
user_manager = UserManager()
|
||||
redirect_page = request.args.get(
|
||||
|
@ -23,7 +23,7 @@ def logout():
|
|||
if not session_token:
|
||||
abort(417, 'Not logged in')
|
||||
|
||||
user, _ = user_manager.authenticate_user_session(session_token)
|
||||
user, _ = user_manager.authenticate_user_session(session_token)[:2]
|
||||
if not user:
|
||||
abort(403, 'Invalid session token')
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ def authenticate_session(req):
|
|||
user_session_token = get_cookie(req, 'session_token')
|
||||
|
||||
if user_session_token:
|
||||
user, _ = user_manager.authenticate_user_session(user_session_token)
|
||||
user, _ = user_manager.authenticate_user_session(user_session_token)[:2]
|
||||
|
||||
return user is not None
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ class UserPlugin(Plugin):
|
|||
executing_user=None,
|
||||
executing_user_password=None,
|
||||
session_token=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create a user. This action needs to be executed by an already existing
|
||||
|
@ -50,7 +50,7 @@ class UserPlugin(Plugin):
|
|||
if not self.user_manager.authenticate_user(
|
||||
executing_user, executing_user_password
|
||||
):
|
||||
user, _ = self.user_manager.authenticate_user_session(session_token)
|
||||
user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
|
||||
if not user:
|
||||
return None, "Invalid credentials and/or session_token"
|
||||
|
||||
|
@ -132,14 +132,14 @@ class UserPlugin(Plugin):
|
|||
if not self.user_manager.authenticate_user(
|
||||
executing_user, executing_user_password
|
||||
):
|
||||
user, _ = self.user_manager.authenticate_user_session(session_token)
|
||||
user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
|
||||
if not user:
|
||||
return None, "Invalid credentials and/or session_token"
|
||||
|
||||
try:
|
||||
return self.user_manager.delete_user(username)
|
||||
except NameError:
|
||||
return None, "No such user: {}".format(username)
|
||||
return None, f"No such user: {username}"
|
||||
|
||||
@action
|
||||
def create_session(self, username, password, code=None, expires_at=None):
|
||||
|
@ -164,6 +164,9 @@ class UserPlugin(Plugin):
|
|||
username=username, password=password, code=code, expires_at=expires_at
|
||||
)
|
||||
|
||||
if isinstance(session, tuple):
|
||||
session = session[0]
|
||||
|
||||
if not session:
|
||||
return None, "Invalid credentials"
|
||||
|
||||
|
@ -193,9 +196,7 @@ class UserPlugin(Plugin):
|
|||
|
||||
"""
|
||||
|
||||
user, _ = self.user_manager.authenticate_user_session(
|
||||
session_token=session_token
|
||||
)
|
||||
user, _ = self.user_manager.authenticate_user_session(session_token)[:2]
|
||||
if not user:
|
||||
return None, 'Invalid session token'
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import base64
|
||||
import datetime
|
||||
import enum
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
@ -10,7 +9,6 @@ from typing import List, Optional, Dict, Tuple, Union
|
|||
|
||||
import rsa
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.orm import make_transient
|
||||
|
||||
from platypush.common.db import Base
|
||||
|
@ -22,24 +20,7 @@ from platypush.exceptions.user import (
|
|||
)
|
||||
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
|
||||
|
||||
|
||||
class AuthenticationStatus(enum.Enum):
|
||||
"""
|
||||
Enum for authentication errors.
|
||||
"""
|
||||
|
||||
OK = ''
|
||||
INVALID_AUTH_TYPE = 'invalid_auth_type'
|
||||
INVALID_CREDENTIALS = 'invalid_credentials'
|
||||
INVALID_METHOD = 'invalid_method'
|
||||
INVALID_JWT_TOKEN = 'invalid_jwt_token'
|
||||
INVALID_OTP_CODE = 'invalid_otp_code'
|
||||
MISSING_OTP_CODE = 'missing_otp_code'
|
||||
MISSING_PASSWORD = 'missing_password'
|
||||
MISSING_USERNAME = 'missing_username'
|
||||
PASSWORD_MISMATCH = 'password_mismatch'
|
||||
REGISTRATION_DISABLED = 'registration_disabled'
|
||||
REGISTRATION_REQUIRED = 'registration_required'
|
||||
from ._model import User, UserSession, UserOtp, UserBackupCode, AuthenticationStatus
|
||||
|
||||
|
||||
class UserManager:
|
||||
|
@ -173,16 +154,16 @@ class UserManager:
|
|||
def authenticate_user(self, username, password, code=None, return_error=False):
|
||||
with self._get_session() as session:
|
||||
return self._authenticate_user(
|
||||
session, username, password, code=code, return_error=return_error
|
||||
session, username, password, code=code, with_status=return_error
|
||||
)
|
||||
|
||||
def authenticate_user_session(self, session_token, with_error=False):
|
||||
def authenticate_user_session(self, session_token, with_status=False):
|
||||
with self._get_session() as session:
|
||||
users_count = session.query(User).count()
|
||||
if not users_count:
|
||||
return (
|
||||
(None, None, AuthenticationStatus.REGISTRATION_REQUIRED)
|
||||
if with_error
|
||||
if with_status
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
|
@ -201,14 +182,14 @@ class UserManager:
|
|||
if not user_session or (expires_at and expires_at < utcnow()):
|
||||
return (
|
||||
(None, None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
if with_error
|
||||
if with_status
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
user = session.query(User).filter_by(user_id=user_session.user_id).first()
|
||||
return (
|
||||
(self._mask_password(user), user_session, AuthenticationStatus.OK)
|
||||
if with_error
|
||||
if with_status
|
||||
else (
|
||||
self._mask_password(user),
|
||||
user_session,
|
||||
|
@ -252,7 +233,7 @@ class UserManager:
|
|||
password,
|
||||
code=None,
|
||||
expires_at=None,
|
||||
error_on_invalid_credentials=False,
|
||||
with_status=False,
|
||||
):
|
||||
with self._get_session(locked=True) as session:
|
||||
user, status = self._authenticate_user( # type: ignore
|
||||
|
@ -260,11 +241,11 @@ class UserManager:
|
|||
username,
|
||||
password,
|
||||
code=code,
|
||||
return_error=error_on_invalid_credentials,
|
||||
with_status=with_status,
|
||||
)
|
||||
|
||||
if not user:
|
||||
return None if not error_on_invalid_credentials else (None, status)
|
||||
return None if not with_status else (None, status)
|
||||
|
||||
if expires_at:
|
||||
if isinstance(expires_at, (int, float)):
|
||||
|
@ -283,7 +264,7 @@ class UserManager:
|
|||
session.add(user_session)
|
||||
session.commit()
|
||||
return user_session, (
|
||||
AuthenticationStatus.OK if not error_on_invalid_credentials else status
|
||||
AuthenticationStatus.OK if not with_status else status
|
||||
)
|
||||
|
||||
def create_otp_secret(
|
||||
|
@ -470,7 +451,7 @@ class UserManager:
|
|||
username: str,
|
||||
password: str,
|
||||
code: Optional[str] = None,
|
||||
return_error: bool = False,
|
||||
with_status: bool = False,
|
||||
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
|
||||
"""
|
||||
:return: :class:`platypush.user.User` instance if the user exists and
|
||||
|
@ -482,7 +463,7 @@ class UserManager:
|
|||
if not user:
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
if not with_status
|
||||
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
)
|
||||
|
||||
|
@ -495,7 +476,7 @@ class UserManager:
|
|||
):
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
if not with_status
|
||||
else (None, AuthenticationStatus.INVALID_CREDENTIALS)
|
||||
)
|
||||
|
||||
|
@ -504,27 +485,27 @@ class UserManager:
|
|||
# The user doesn't have 2FA enabled and the password is correct:
|
||||
# authentication successful
|
||||
if not otp_secret:
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
return user if not with_status else (user, AuthenticationStatus.OK)
|
||||
|
||||
# The user has 2FA enabled but the code is missing
|
||||
if not code:
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
if not with_status
|
||||
else (None, AuthenticationStatus.MISSING_OTP_CODE)
|
||||
)
|
||||
|
||||
if self.validate_otp_code(username, code):
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
return user if not with_status else (user, AuthenticationStatus.OK)
|
||||
|
||||
if not self.validate_backup_code(username, code):
|
||||
return (
|
||||
None
|
||||
if not return_error
|
||||
if not with_status
|
||||
else (None, AuthenticationStatus.INVALID_OTP_CODE)
|
||||
)
|
||||
|
||||
return user if not return_error else (user, AuthenticationStatus.OK)
|
||||
return user if not with_status else (user, AuthenticationStatus.OK)
|
||||
|
||||
def refresh_user_backup_codes(self, username: str):
|
||||
"""
|
||||
|
@ -619,60 +600,15 @@ class UserManager:
|
|||
session.commit()
|
||||
return True
|
||||
|
||||
def enable_mfa(self, username: str):
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
class User(Base):
|
||||
"""Models the User table"""
|
||||
|
||||
__tablename__ = 'user'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
user_id = Column(Integer, primary_key=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
password = Column(String)
|
||||
password_salt = Column(String)
|
||||
hmac_iterations = Column(Integer)
|
||||
created_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserSession(Base):
|
||||
"""Models the UserSession table"""
|
||||
|
||||
__tablename__ = 'user_session'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
session_id = Column(Integer, primary_key=True)
|
||||
session_token = Column(String, unique=True, nullable=False)
|
||||
csrf_token = Column(String, unique=True)
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserOtp(Base):
|
||||
"""
|
||||
Models the UserOtp table, which contains the OTP secrets for each user.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_otp'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
otp_secret = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserBackupCode(Base):
|
||||
"""
|
||||
Models the UserBackupCode table, which contains the backup codes for each
|
||||
user with 2FA enabled.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_backup_code'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
code = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
self.create_otp_secret(username)
|
||||
self.refresh_user_backup_codes(username)
|
||||
return True
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
99
platypush/user/_model.py
Normal file
99
platypush/user/_model.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
from dataclasses import dataclass, field
|
||||
import datetime
|
||||
import enum
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
|
||||
from platypush.common.db import Base
|
||||
|
||||
|
||||
class AuthenticationStatus(enum.Enum):
|
||||
"""
|
||||
Enum for authentication errors.
|
||||
"""
|
||||
|
||||
OK = ''
|
||||
INVALID_AUTH_TYPE = 'invalid_auth_type'
|
||||
INVALID_CREDENTIALS = 'invalid_credentials'
|
||||
INVALID_METHOD = 'invalid_method'
|
||||
INVALID_JWT_TOKEN = 'invalid_jwt_token'
|
||||
INVALID_OTP_CODE = 'invalid_otp_code'
|
||||
MISSING_OTP_CODE = 'missing_otp_code'
|
||||
MISSING_PASSWORD = 'missing_password'
|
||||
MISSING_USERNAME = 'missing_username'
|
||||
PASSWORD_MISMATCH = 'password_mismatch'
|
||||
REGISTRATION_DISABLED = 'registration_disabled'
|
||||
REGISTRATION_REQUIRED = 'registration_required'
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Models the User table"""
|
||||
|
||||
__tablename__ = 'user'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
user_id = Column(Integer, primary_key=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
password = Column(String)
|
||||
password_salt = Column(String)
|
||||
hmac_iterations = Column(Integer)
|
||||
created_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserSession(Base):
|
||||
"""Models the UserSession table"""
|
||||
|
||||
__tablename__ = 'user_session'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
session_id = Column(Integer, primary_key=True)
|
||||
session_token = Column(String, unique=True, nullable=False)
|
||||
csrf_token = Column(String, unique=True)
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserOtp(Base):
|
||||
"""
|
||||
Models the UserOtp table, which contains the OTP secrets for each user.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_otp'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
otp_secret = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
class UserBackupCode(Base):
|
||||
"""
|
||||
Models the UserBackupCode table, which contains the backup codes for each
|
||||
user with 2FA enabled.
|
||||
"""
|
||||
|
||||
__tablename__ = 'user_backup_code'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
code = Column(String, nullable=False, unique=True)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserResponse:
|
||||
"""
|
||||
Dataclass containing full information about a user (minus the password).
|
||||
"""
|
||||
|
||||
user_id: int
|
||||
username: str
|
||||
otp_secret: Optional[str] = None
|
||||
session_token: Optional[str] = None
|
||||
created_at: Optional[datetime.datetime] = None
|
||||
backup_codes: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
Loading…
Reference in a new issue