forked from platypush/platypush
[#339] Backend preparation for 2FA support.
This commit is contained in:
parent
2cbb005c67
commit
8ec1ca8543
6 changed files with 353 additions and 59 deletions
207
platypush/backend/http/app/routes/otp.py
Normal file
207
platypush/backend/http/app/routes/otp.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, request
|
||||
|
||||
from platypush.backend.http.app import template_folder
|
||||
from platypush.backend.http.app.utils import UserAuthStatus, authenticate
|
||||
from platypush.backend.http.utils import HttpUtils
|
||||
from platypush.exceptions.user import (
|
||||
InvalidCredentialsException,
|
||||
InvalidOtpCodeException,
|
||||
UserException,
|
||||
)
|
||||
from platypush.config import Config
|
||||
from platypush.context import get_plugin
|
||||
from platypush.user import UserManager
|
||||
|
||||
otp = Blueprint('otp', __name__, template_folder=template_folder)
|
||||
|
||||
# Declare routes list
|
||||
__routes__ = [
|
||||
otp,
|
||||
]
|
||||
|
||||
|
||||
def _get_otp_and_qrcode():
|
||||
otp = get_plugin('otp') # pylint: disable=redefined-outer-name
|
||||
qrcode = get_plugin('qrcode')
|
||||
assert (
|
||||
otp and qrcode
|
||||
), 'The otp and/or qrcode plugins are not available in your installation'
|
||||
|
||||
return otp, qrcode
|
||||
|
||||
|
||||
def _get_username():
|
||||
user = HttpUtils.current_user()
|
||||
if not user:
|
||||
raise InvalidCredentialsException('Invalid user session')
|
||||
|
||||
return str(user.username)
|
||||
|
||||
|
||||
def _get_otp_uri_and_qrcode(username: str, otp_secret: Optional[str] = None):
|
||||
if not otp_secret:
|
||||
return None, None
|
||||
|
||||
otp, qrcode = _get_otp_and_qrcode() # pylint: disable=redefined-outer-name
|
||||
otp_uri = (
|
||||
otp.provision_time_otp(
|
||||
name=username,
|
||||
secret=otp_secret,
|
||||
issuer=f'platypush@{Config.get_device_id()}',
|
||||
).output
|
||||
if otp_secret
|
||||
else None
|
||||
)
|
||||
|
||||
otp_qrcode = (
|
||||
qrcode.generate(content=otp_uri, format='png').output.get('data')
|
||||
if otp_uri
|
||||
else None
|
||||
)
|
||||
|
||||
return otp_uri, otp_qrcode
|
||||
|
||||
|
||||
def _verify_code(code: str, otp_secret: str) -> bool:
|
||||
otp, _ = _get_otp_and_qrcode() # pylint: disable=redefined-outer-name
|
||||
return otp.verify_time_otp(otp=code, secret=otp_secret).output
|
||||
|
||||
|
||||
def _dump_response(
|
||||
username: str,
|
||||
otp_secret: Optional[str] = None,
|
||||
backup_codes: Optional[List[str]] = None,
|
||||
):
|
||||
otp_uri, otp_qrcode = _get_otp_uri_and_qrcode(username, otp_secret)
|
||||
return jsonify(
|
||||
{
|
||||
'username': username,
|
||||
'otp_secret': otp_secret,
|
||||
'otp_uri': otp_uri,
|
||||
'qrcode': otp_qrcode,
|
||||
'backup_codes': backup_codes or [],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_otp():
|
||||
username = _get_username()
|
||||
user_manager = UserManager()
|
||||
otp_secret = user_manager.get_otp_secret(username)
|
||||
backup_codes = user_manager.get_user_backup_codes(username) if otp_secret else []
|
||||
return _dump_response(
|
||||
username=username,
|
||||
otp_secret=otp_secret,
|
||||
backup_codes=[str(c.code) for c in backup_codes],
|
||||
)
|
||||
|
||||
|
||||
def _authenticate_user(username: str, password: Optional[str]):
|
||||
assert password, 'The password field is required when setting up OTP'
|
||||
user, auth_status = UserManager().authenticate_user( # type: ignore
|
||||
username, password, skip_2fa=True, with_status=True
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise InvalidCredentialsException(auth_status.value[2])
|
||||
|
||||
|
||||
def _post_otp():
|
||||
body = request.json
|
||||
assert body, 'Invalid request body'
|
||||
|
||||
username = _get_username()
|
||||
dry_run = body.get('dry_run', False)
|
||||
otp_secret = body.get('otp_secret')
|
||||
|
||||
if not dry_run:
|
||||
_authenticate_user(username, body.get('password'))
|
||||
|
||||
if otp_secret:
|
||||
code = body.get('code')
|
||||
assert code, 'The code field is required when setting up OTP'
|
||||
|
||||
if not _verify_code(code, otp_secret):
|
||||
raise InvalidOtpCodeException()
|
||||
|
||||
user_manager = UserManager()
|
||||
user_otp, backup_codes = user_manager.enable_otp(
|
||||
username=username,
|
||||
otp_secret=otp_secret,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
return _dump_response(
|
||||
username=username,
|
||||
otp_secret=str(user_otp.otp_secret),
|
||||
backup_codes=backup_codes,
|
||||
)
|
||||
|
||||
|
||||
def _delete_otp():
|
||||
body = request.json
|
||||
assert body, 'Invalid request body'
|
||||
|
||||
username = _get_username()
|
||||
_authenticate_user(username, body.get('password'))
|
||||
|
||||
user_manager = UserManager()
|
||||
user_manager.disable_otp(username)
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
|
||||
@otp.route('/otp/config', methods=['GET', 'POST', 'DELETE'])
|
||||
@authenticate()
|
||||
def otp_route():
|
||||
"""
|
||||
:return: The user's current MFA/OTP configuration:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"username": "testuser",
|
||||
"otp_secret": "JBSA6ZUZ5DPEK7YV",
|
||||
"otp_uri": "otpauth://totp/testuser?secret=JBSA6ZUZ5DPEK7YV&issuer=platypush@localhost",
|
||||
"qrcode": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAYAAACtWK6eAAABwklEQVR4nO3dMW7CQBAF0",
|
||||
"backup_codes": [
|
||||
"1A2B3C4D5E",
|
||||
"6F7G8H9I0J",
|
||||
"KLMNOPQRST",
|
||||
"UVWXYZ1234",
|
||||
"567890ABCD",
|
||||
"EFGHIJKLMN",
|
||||
"OPQRSTUVWX",
|
||||
"YZ12345678",
|
||||
"90ABCDEF12",
|
||||
"34567890AB"
|
||||
]
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
if request.method.lower() == 'get':
|
||||
return _get_otp()
|
||||
|
||||
if request.method.lower() == 'post':
|
||||
return _post_otp()
|
||||
|
||||
if request.method.lower() == 'delete':
|
||||
return _delete_otp()
|
||||
|
||||
return jsonify({'error': 'Method not allowed'}), 405
|
||||
except AssertionError as e:
|
||||
return jsonify({'error': str(e)}), 400
|
||||
except InvalidCredentialsException:
|
||||
return UserAuthStatus.INVALID_CREDENTIALS.to_response()
|
||||
except InvalidOtpCodeException:
|
||||
return UserAuthStatus.INVALID_OTP_CODE.to_response()
|
||||
except UserException as e:
|
||||
return jsonify({'error': e.__class__.__name__, 'message': str(e)}), 401
|
||||
except Exception as e:
|
||||
HttpUtils.log.error(f'Error while processing OTP request: {e}', exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
|
@ -7,6 +7,7 @@ class UserException(PlatypushException):
|
|||
"""
|
||||
Base class for all user exceptions.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, user: Optional[Union[str, int]] = None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.user = user
|
||||
|
@ -16,6 +17,7 @@ class AuthenticationException(UserException):
|
|||
"""
|
||||
Authentication error exception.
|
||||
"""
|
||||
|
||||
def __init__(self, error='Unauthorized', *args, **kwargs):
|
||||
super().__init__(error, *args, **kwargs)
|
||||
|
||||
|
@ -24,6 +26,7 @@ class InvalidTokenException(AuthenticationException):
|
|||
"""
|
||||
Exception raised in case of wrong user token.
|
||||
"""
|
||||
|
||||
def __init__(self, error='Invalid user token', *args, **kwargs):
|
||||
super().__init__(error, *args, **kwargs)
|
||||
|
||||
|
@ -32,6 +35,7 @@ class InvalidCredentialsException(AuthenticationException):
|
|||
"""
|
||||
Exception raised in case of wrong user token.
|
||||
"""
|
||||
|
||||
def __init__(self, error='Invalid credentials', *args, **kwargs):
|
||||
super().__init__(error, *args, **kwargs)
|
||||
|
||||
|
@ -40,5 +44,26 @@ class InvalidJWTTokenException(InvalidTokenException):
|
|||
"""
|
||||
Exception raised in case of wrong/invalid JWT token.
|
||||
"""
|
||||
|
||||
def __init__(self, error='Invalid JWT token', *args, **kwargs):
|
||||
super().__init__(error, *args, **kwargs)
|
||||
|
||||
|
||||
class InvalidOtpCodeException(AuthenticationException):
|
||||
"""
|
||||
Exception raised in case of wrong OTP code.
|
||||
"""
|
||||
|
||||
def __init__(self, error='Invalid OTP code', *args, **kwargs):
|
||||
super().__init__(error, *args, **kwargs)
|
||||
|
||||
|
||||
class OtpRecordAlreadyExistsException(UserException):
|
||||
"""
|
||||
Exception raised in case of an OTP record already existing for a user.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, error='An OTP record already exists for this user', **kwargs
|
||||
):
|
||||
super().__init__(*args, error, **kwargs)
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Optional
|
|||
|
||||
import pyotp
|
||||
|
||||
from platypush import Response
|
||||
from platypush.config import Config
|
||||
from platypush.plugins import Plugin, action
|
||||
|
||||
|
@ -19,7 +18,7 @@ class OtpPlugin(Plugin):
|
|||
secret: Optional[str] = None,
|
||||
secret_path: Optional[str] = None,
|
||||
provisioning_name: Optional[str] = None,
|
||||
issuer_name: Optional[str] = None,
|
||||
issuer: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
|
@ -29,17 +28,17 @@ class OtpPlugin(Plugin):
|
|||
generated.
|
||||
:param provisioning_name: If you want to use the Google Authenticator, you can specify the default
|
||||
email address to associate to your OTPs for the provisioning process here.
|
||||
:param issuer_name: If you want to use the Google Authenticator, you can specify the default
|
||||
:param issuer: If you want to use the Google Authenticator, you can specify the default
|
||||
issuer name to display on your OTPs here.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if not secret_path:
|
||||
secret_path = os.path.join(Config.get('workdir'), 'otp', 'secret')
|
||||
secret_path = os.path.join(Config.get_workdir(), 'otp', 'secret')
|
||||
|
||||
self.secret_path = secret_path
|
||||
self.secret = secret
|
||||
self.provisioning_name = provisioning_name
|
||||
self.issuer_name = issuer_name
|
||||
self.issuer = issuer
|
||||
|
||||
def _get_secret_from_path(self, secret_path: str) -> str:
|
||||
if not os.path.isfile(secret_path):
|
||||
|
@ -75,7 +74,16 @@ class OtpPlugin(Plugin):
|
|||
return pyotp.HOTP(self._get_secret(secret, secret_path))
|
||||
|
||||
@action
|
||||
def refresh_secret(self, secret_path: Optional[str] = None) -> Response:
|
||||
def generate_secret(self) -> str:
|
||||
"""
|
||||
Generate a new secret token for key generation.
|
||||
|
||||
:return: The new secret token.
|
||||
"""
|
||||
return pyotp.random_base32()
|
||||
|
||||
@action
|
||||
def refresh_secret(self, secret_path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Refresh the secret token for key generation given a secret path.
|
||||
|
||||
|
@ -162,7 +170,7 @@ class OtpPlugin(Plugin):
|
|||
def provision_time_otp(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
issuer_name: Optional[str] = None,
|
||||
issuer: Optional[str] = None,
|
||||
secret: Optional[str] = None,
|
||||
secret_path: Optional[str] = None,
|
||||
) -> str:
|
||||
|
@ -171,23 +179,23 @@ class OtpPlugin(Plugin):
|
|||
|
||||
:param name: Name or e-mail address associated to the account used by the Google Authenticator.
|
||||
If None is specified then the value will be read from the configured ``provisioning_name``.
|
||||
:param issuer_name: Name of the issuer of the OTP (default: default configured ``issuer_name`` or None).
|
||||
:param issuer: Name of the issuer of the OTP (default: default configured ``issuer`` or None).
|
||||
:param secret: Secret token to be used (overrides configured ``secret``).
|
||||
:param secret_path: File containing the secret to be used (overrides configured ``secret_path``).
|
||||
:return: Generated provisioning URI.
|
||||
"""
|
||||
name = name or self.provisioning_name
|
||||
issuer_name = issuer_name or self.issuer_name
|
||||
issuer = issuer or self.issuer
|
||||
assert name, 'No account name or default provisioning address provided'
|
||||
|
||||
_otp = self._get_topt(secret, secret_path)
|
||||
return _otp.provisioning_uri(name, issuer_name=issuer_name)
|
||||
return _otp.provisioning_uri(name, issuer_name=issuer)
|
||||
|
||||
@action
|
||||
def provision_counter_otp(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
issuer_name: Optional[str] = None,
|
||||
issuer: Optional[str] = None,
|
||||
initial_count=0,
|
||||
secret: Optional[str] = None,
|
||||
secret_path: Optional[str] = None,
|
||||
|
@ -197,19 +205,19 @@ class OtpPlugin(Plugin):
|
|||
|
||||
:param name: Name or e-mail address associated to the account used by the Google Authenticator.
|
||||
If None is specified then the value will be read from the configured ``provisioning_name``.
|
||||
:param issuer_name: Name of the issuer of the OTP (default: default configured ``issuer_name`` or None).
|
||||
:param issuer: Name of the issuer of the OTP (default: default configured ``issuer`` or None).
|
||||
:param initial_count: Initial value for the counter (default: 0).
|
||||
:param secret: Secret token to be used (overrides configured ``secret``).
|
||||
:param secret_path: File containing the secret to be used (overrides configured ``secret_path``).
|
||||
:return: Generated provisioning URI.
|
||||
"""
|
||||
name = name or self.provisioning_name
|
||||
issuer_name = issuer_name or self.issuer_name
|
||||
issuer = issuer or self.issuer
|
||||
assert name, 'No account name or default provisioning address provided'
|
||||
|
||||
_otp = self._get_hopt(secret, secret_path)
|
||||
return _otp.provisioning_uri(
|
||||
name, issuer_name=issuer_name, initial_count=initial_count
|
||||
name, issuer_name=issuer, initial_count=initial_count
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ class UserPlugin(Plugin):
|
|||
:param username: Username.
|
||||
:param password: Password.
|
||||
:param code: Optional 2FA code, if 2FA is enabled for the user.
|
||||
:param return_error_details: If True then return the error details in
|
||||
:param return_details: If True then return the error details in
|
||||
case of authentication failure.
|
||||
:return: If ``return_details`` is False (default), the action returns
|
||||
True if the provided credentials are valid, False otherwise.
|
||||
|
@ -96,7 +96,7 @@ class UserPlugin(Plugin):
|
|||
|
||||
"""
|
||||
response = self.user_manager.authenticate_user(
|
||||
username, password, code=code, return_error=return_details
|
||||
username, password, code=code, with_status=return_details
|
||||
)
|
||||
|
||||
if return_details:
|
||||
|
|
|
@ -17,6 +17,7 @@ from platypush.context import get_plugin
|
|||
from platypush.exceptions.user import (
|
||||
InvalidJWTTokenException,
|
||||
InvalidCredentialsException,
|
||||
OtpRecordAlreadyExistsException,
|
||||
)
|
||||
from platypush.utils import get_or_generate_stored_rsa_key_pair, utcnow
|
||||
|
||||
|
@ -63,7 +64,7 @@ class UserManager:
|
|||
"""
|
||||
Get or generate the OTP RSA key pair.
|
||||
"""
|
||||
return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=4096)
|
||||
return get_or_generate_stored_rsa_key_pair(cls._otp_keyfile, size=2048)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt(data: Union[str, bytes, dict, list, tuple], key: rsa.PublicKey) -> str:
|
||||
|
@ -151,10 +152,17 @@ class UserManager:
|
|||
session.commit()
|
||||
return True
|
||||
|
||||
def authenticate_user(self, username, password, code=None, return_error=False):
|
||||
def authenticate_user(
|
||||
self, username, password, code=None, skip_2fa=False, with_status=False
|
||||
):
|
||||
with self._get_session() as session:
|
||||
return self._authenticate_user(
|
||||
session, username, password, code=code, with_status=return_error
|
||||
session,
|
||||
username,
|
||||
password,
|
||||
code=code,
|
||||
skip_2fa=skip_2fa,
|
||||
with_status=with_status,
|
||||
)
|
||||
|
||||
def authenticate_user_session(self, session_token, with_status=False):
|
||||
|
@ -268,32 +276,44 @@ class UserManager:
|
|||
)
|
||||
|
||||
def create_otp_secret(
|
||||
self, username: str, expires_at: Optional[datetime.datetime] = None
|
||||
self,
|
||||
username: str,
|
||||
expires_at: Optional[datetime.datetime] = None,
|
||||
otp_secret: Optional[str] = None,
|
||||
dry_run: bool = False,
|
||||
):
|
||||
pubkey, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
|
||||
# Generate a new OTP secret and encrypt it with the OTP RSA key pair
|
||||
otp_secret = "".join(
|
||||
otp_secret = otp_secret or "".join(
|
||||
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(32)
|
||||
)
|
||||
|
||||
encrypted_secret = self._encrypt(otp_secret, pubkey)
|
||||
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
assert user, f'No such user: {username}'
|
||||
if not user:
|
||||
raise InvalidCredentialsException()
|
||||
|
||||
# Create a new OTP secret
|
||||
user_otp = UserOtp(
|
||||
user_id=user.user_id,
|
||||
otp_secret=encrypted_secret,
|
||||
otp_secret=otp_secret,
|
||||
created_at=utcnow(),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
# Store a copy of the OTP secret encrypted with the RSA public key
|
||||
pubkey, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
encrypted_secret = self._encrypt(otp_secret, pubkey)
|
||||
encrypted_otp = UserOtp(
|
||||
user_id=user_otp.user_id,
|
||||
otp_secret=encrypted_secret,
|
||||
created_at=user_otp.created_at,
|
||||
expires_at=user_otp.expires_at,
|
||||
)
|
||||
|
||||
# Remove any existing OTP secret and replace it with the new one
|
||||
session.query(UserOtp).filter_by(user_id=user.user_id).delete()
|
||||
session.add(user_otp)
|
||||
session.add(encrypted_otp)
|
||||
session.commit()
|
||||
|
||||
return user_otp
|
||||
|
@ -445,6 +465,7 @@ class UserManager:
|
|||
username: str,
|
||||
password: str,
|
||||
code: Optional[str] = None,
|
||||
skip_2fa: bool = False,
|
||||
with_status: bool = False,
|
||||
) -> Union[Optional['User'], Tuple[Optional['User'], 'AuthenticationStatus']]:
|
||||
"""
|
||||
|
@ -478,7 +499,7 @@ class UserManager:
|
|||
|
||||
# The user doesn't have 2FA enabled and the password is correct:
|
||||
# authentication successful
|
||||
if not otp_secret:
|
||||
if skip_2fa or not otp_secret:
|
||||
return user if not with_status else (user, AuthenticationStatus.OK)
|
||||
|
||||
# The user has 2FA enabled but the code is missing
|
||||
|
@ -501,23 +522,27 @@ class UserManager:
|
|||
|
||||
return user if not with_status else (user, AuthenticationStatus.OK)
|
||||
|
||||
def refresh_user_backup_codes(self, username: str):
|
||||
def refresh_user_backup_codes(self, username: str) -> List[str]:
|
||||
"""
|
||||
Refresh the backup codes for a user with 2FA enabled.
|
||||
"""
|
||||
backup_codes = [
|
||||
"".join(
|
||||
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16)
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
return []
|
||||
|
||||
session.query(UserBackupCode).filter_by(user_id=user.user_id).delete()
|
||||
pub_key, _ = self._get_or_generate_otp_rsa_key_pair()
|
||||
stored_codes = []
|
||||
|
||||
for _ in range(10):
|
||||
backup_code = "".join(
|
||||
random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") for _ in range(16)
|
||||
)
|
||||
|
||||
for backup_code in backup_codes:
|
||||
user_backup_code = UserBackupCode(
|
||||
user_id=user.user_id,
|
||||
code=self._encrypt(backup_code, pub_key),
|
||||
|
@ -526,9 +551,10 @@ class UserManager:
|
|||
)
|
||||
|
||||
session.add(user_backup_code)
|
||||
stored_codes.append(backup_code)
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
return stored_codes
|
||||
|
||||
def get_user_backup_codes(self, username: str) -> List['UserBackupCode']:
|
||||
with self._get_session() as session:
|
||||
|
@ -537,12 +563,17 @@ class UserManager:
|
|||
return []
|
||||
|
||||
_, priv_key = self._get_or_generate_otp_rsa_key_pair()
|
||||
codes = session.query(UserBackupCode).filter_by(user_id=user.user_id).all()
|
||||
|
||||
for code in codes:
|
||||
code.code = self._decrypt(code.code, priv_key)
|
||||
|
||||
return codes
|
||||
return [
|
||||
UserBackupCode(
|
||||
user_id=code.user_id,
|
||||
code=self._decrypt(code.code, priv_key),
|
||||
created_at=code.created_at,
|
||||
expires_at=code.expires_at,
|
||||
)
|
||||
for code in session.query(UserBackupCode)
|
||||
.filter_by(user_id=user.user_id)
|
||||
.all()
|
||||
]
|
||||
|
||||
def validate_backup_code(self, username: str, code: str) -> bool:
|
||||
with self._get_session() as session:
|
||||
|
@ -583,7 +614,7 @@ class UserManager:
|
|||
|
||||
return otp.verify_time_otp(otp=code, secret=otp_secret)
|
||||
|
||||
def disable_mfa(self, username: str):
|
||||
def disable_otp(self, username: str):
|
||||
with self._get_session(locked=True) as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
|
@ -594,15 +625,30 @@ class UserManager:
|
|||
session.commit()
|
||||
return True
|
||||
|
||||
def enable_mfa(self, username: str):
|
||||
def enable_otp(
|
||||
self,
|
||||
username: str,
|
||||
dry_run: bool = False,
|
||||
otp_secret: Optional[str] = None,
|
||||
):
|
||||
with self._get_session() as session:
|
||||
user = self._get_user(session, username)
|
||||
if not user:
|
||||
return False
|
||||
raise InvalidCredentialsException()
|
||||
|
||||
self.create_otp_secret(username)
|
||||
self.refresh_user_backup_codes(username)
|
||||
return True
|
||||
user_otp = session.query(UserOtp).filter_by(user_id=user.user_id).first()
|
||||
if user_otp:
|
||||
raise OtpRecordAlreadyExistsException()
|
||||
|
||||
user_otp = self.create_otp_secret(
|
||||
username, otp_secret=otp_secret, dry_run=dry_run
|
||||
)
|
||||
|
||||
backup_codes = (
|
||||
self.refresh_user_backup_codes(username) if not dry_run else []
|
||||
)
|
||||
|
||||
return user_otp, backup_codes
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
import enum
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
)
|
||||
|
||||
from platypush.common.db import Base
|
||||
|
||||
|
@ -28,9 +35,8 @@ class User(Base):
|
|||
"""Models the User table"""
|
||||
|
||||
__tablename__ = 'user'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
user_id = Column(Integer, primary_key=True)
|
||||
user_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
password = Column(String)
|
||||
password_salt = Column(String)
|
||||
|
@ -42,9 +48,8 @@ class UserSession(Base):
|
|||
"""Models the UserSession table"""
|
||||
|
||||
__tablename__ = 'user_session'
|
||||
__table_args__ = {'sqlite_autoincrement': True}
|
||||
|
||||
session_id = Column(Integer, primary_key=True)
|
||||
session_id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
session_token = Column(String, unique=True, nullable=False)
|
||||
csrf_token = Column(String, unique=True)
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
|
||||
|
@ -73,10 +78,13 @@ class UserBackupCode(Base):
|
|||
|
||||
__tablename__ = 'user_backup_code'
|
||||
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'), primary_key=True)
|
||||
code = Column(String, nullable=False, unique=True)
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('user.user_id'))
|
||||
code = Column(String, nullable=False)
|
||||
created_at = Column(DateTime)
|
||||
expires_at = Column(DateTime)
|
||||
|
||||
UniqueConstraint(user_id, code)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
||||
|
|
Loading…
Reference in a new issue