import base64
import datetime
import hashlib
import json
import random
import secrets
import time
from typing import Optional, Dict

import bcrypt
import rsa
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base

from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.utils import get_or_generate_jwt_rsa_key_pair

Base = declarative_base()


class UserManager:
    """
    Main class for managing platform users.

    All the timestamps stored in the ``user`` and ``user_session`` tables
    (``created_at``, ``expires_at``) are naive UTC datetimes.
    """

    # noinspection PyProtectedMember
    def __init__(self):
        """
        :raises ModuleNotFoundError: if the ``db`` plugin is not enabled/configured.
        """
        db_plugin = get_plugin('db')
        if not db_plugin:
            raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support')

        self._engine = db_plugin._get_engine()

    def get_user(self, username):
        """
        Look up a user by name.

        :param username: User name.
        :return: The :class:`User` record with ``password`` blanked out, or
            ``None`` if no such user exists.
        """
        session = self._get_db_session()
        return self._mask_password(self._get_user(session, username))

    def get_user_count(self):
        """
        :return: The number of registered users.
        """
        session = self._get_db_session()
        return session.query(User).count()

    def get_users(self):
        """
        :return: A query over all the :class:`User` records. Note that,
            unlike :meth:`get_user`, the password hashes are not masked here.
        """
        session = self._get_db_session()
        return session.query(User)

    def create_user(self, username, password, **kwargs):
        """
        Create a new user.

        :param username: User name (must be non-empty and not already taken).
        :param password: Clear-text password; it is stored bcrypt-hashed.
        :param kwargs: Extra column values forwarded to :class:`User`.
        :return: The new :class:`User` record with ``password`` blanked out.
        :raises ValueError: on empty username or password.
        :raises NameError: if the user already exists.
        """
        session = self._get_db_session()
        if not username:
            raise ValueError('Invalid or empty username')
        if not password:
            raise ValueError('Please provide a password for the user')

        if self._get_user(session, username):
            raise NameError('The user {} already exists'.format(username))

        record = User(username=username, password=self._encrypt_password(password),
                      created_at=datetime.datetime.utcnow(), **kwargs)

        session.add(record)
        session.commit()
        return self._mask_password(self._get_user(session, username))

    def update_password(self, username, old_password, new_password):
        """
        Change a user's password after verifying the current one.

        :param username: User name.
        :param old_password: Current clear-text password.
        :param new_password: New clear-text password.
        :return: ``True`` on success, ``False`` if the user does not exist or
            ``old_password`` is wrong.
        """
        session = self._get_db_session()
        # _authenticate_user already returns the (session-attached) user record,
        # so there is no need to query it again.
        user = self._authenticate_user(session, username, old_password)
        if not user:
            return False

        user.password = self._encrypt_password(new_password)
        session.commit()
        return True

    def authenticate_user(self, username, password):
        """
        :return: The :class:`User` record if the credentials are valid,
            ``None`` otherwise.
        """
        session = self._get_db_session()
        return self._authenticate_user(session, username, password)

    def authenticate_user_session(self, session_token):
        """
        Validate a session token.

        :param session_token: Session token, as returned in
            :meth:`create_user_session`.
        :return: A ``(user, user_session)`` tuple if the token exists, is not
            expired and belongs to an existing user (the user's password is
            blanked out), ``(None, None)`` otherwise.
        """
        session = self._get_db_session()
        user_session = session.query(UserSession).filter_by(session_token=session_token).first()

        if not user_session or (
                user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
            return None, None

        user = session.query(User).filter_by(user_id=user_session.user_id).first()
        if not user:
            # Dangling session whose user no longer exists
            return None, None

        return self._mask_password(user), user_session

    def delete_user(self, username):
        """
        Delete a user together with all of their sessions.

        :param username: User name.
        :return: ``True``.
        :raises NameError: if the user does not exist.
        """
        session = self._get_db_session()
        user = self._get_user(session, username)
        if not user:
            raise NameError('No such user: {}'.format(username))

        for user_session in session.query(UserSession).filter_by(user_id=user.user_id).all():
            session.delete(user_session)

        session.delete(user)
        session.commit()
        return True

    def delete_user_session(self, session_token):
        """
        Delete (log out) a user session.

        :param session_token: Session token.
        :return: ``True`` if the session existed and was deleted, ``False`` otherwise.
        """
        session = self._get_db_session()
        user_session = session.query(UserSession).filter_by(session_token=session_token).first()

        if not user_session:
            return False

        session.delete(user_session)
        session.commit()
        return True

    def create_user_session(self, username, password, expires_at=None):
        """
        Authenticate a user and create a new session for them.

        :param username: User name.
        :param password: Clear-text password.
        :param expires_at: Optional expiration, as a UNIX timestamp (int/float),
            an ISO-8601 string or a ``datetime``. Naive values are assumed to be UTC.
        :return: The new :class:`UserSession` record, or ``None`` on invalid credentials.
        """
        session = self._get_db_session()
        user = self._authenticate_user(session, username, password)
        if not user:
            return None

        user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
                                   csrf_token=self.generate_session_token(),
                                   created_at=datetime.datetime.utcnow(),
                                   expires_at=self._normalize_expires_at(expires_at))

        session.add(user_session)
        session.commit()
        return user_session

    @staticmethod
    def _normalize_expires_at(expires_at):
        """
        Convert an expiration value into the naive-UTC ``datetime`` format used
        by the DB and by the ``utcnow()`` comparison in
        :meth:`authenticate_user_session`.

        :param expires_at: Falsy value, UNIX timestamp, ISO-8601 string or ``datetime``.
        :return: A naive UTC ``datetime``, or ``None`` for falsy input.
        """
        if not expires_at:
            return None

        if isinstance(expires_at, (int, float)):
            # fromtimestamp() without tz would yield *local* time, which would be
            # off by the UTC offset when compared against utcnow().
            return datetime.datetime.fromtimestamp(
                expires_at, tz=datetime.timezone.utc
            ).replace(tzinfo=None)

        if isinstance(expires_at, str):
            expires_at = datetime.datetime.fromisoformat(expires_at)

        if isinstance(expires_at, datetime.datetime) and expires_at.tzinfo is not None:
            # Aware datetimes can't be compared with the naive utcnow()
            expires_at = expires_at.astimezone(datetime.timezone.utc).replace(tzinfo=None)

        return expires_at

    @staticmethod
    def _mask_password(user):
        """
        Blank out the password hash of a user record before returning it to callers.

        :param user: :class:`User` record or ``None``.
        :return: The same object (``None`` is passed through).
        """
        if user is not None:
            user.password = None
        return user

    @staticmethod
    def _get_user(session, username):
        """
        :return: The :class:`User` record with the given name, or ``None``.
        """
        return session.query(User).filter_by(username=username).first()

    @staticmethod
    def _encrypt_password(pwd):
        """
        :param pwd: Clear-text password (``str`` or ``bytes``).
        :return: The bcrypt hash (cost factor 12) as a string.
        """
        if isinstance(pwd, str):
            pwd = pwd.encode()
        return bcrypt.hashpw(pwd, bcrypt.gensalt(12)).decode()

    @classmethod
    def _check_password(cls, pwd, hashed_pwd):
        """
        :return: ``True`` if ``pwd`` matches the bcrypt hash ``hashed_pwd``.
        """
        return bcrypt.checkpw(cls._to_bytes(pwd), cls._to_bytes(hashed_pwd))

    @staticmethod
    def _to_bytes(data) -> bytes:
        """
        Encode ``str`` input to bytes; any other value is returned unchanged.
        """
        if isinstance(data, str):
            data = data.encode()
        return data

    @staticmethod
    def generate_session_token():
        """
        Generate a random token for sessions/CSRF protection.

        :return: 64 hexadecimal characters (256 bits) drawn from a
            cryptographically secure source. ``random`` must not be used here,
            since its output is predictable.
        """
        return secrets.token_hex(32)

    def get_user_by_session(self, session_token: str):
        """
        Get a user associated to a session token.

        Note that the session expiration is not checked here; use
        :meth:`authenticate_user_session` for that.

        :param session_token: Session token.
        :return: The :class:`User` record (password blanked out), or ``None``.
        """
        session = self._get_db_session()
        user = (
            session.query(User)
            .join(UserSession)
            .filter(UserSession.session_token == session_token)
            .first()
        )
        return self._mask_password(user)

    def generate_jwt_token(self, username: str, password: str,
                           expires_at: Optional[datetime.datetime] = None) -> str:
        """
        Create a user JWT token for API usage.

        :param username: User name.
        :param password: Password.
        :param expires_at: Expiration datetime of the token.
        :return: The generated JWT token as a string.
        :raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
        """
        user = self.authenticate_user(username, password)
        if not user:
            raise InvalidCredentialsException()

        pub_key, _ = get_or_generate_jwt_rsa_key_pair()
        payload = json.dumps(
            {
                'username': username,
                'created_at': datetime.datetime.now().timestamp(),
                'expires_at': expires_at.timestamp() if expires_at else None,
            },
            sort_keys=True,
            indent=None,
        )

        # NOTE(review): rsa.encrypt can only encrypt payloads shorter than the
        # key size minus padding, so very long usernames may fail here.
        return base64.b64encode(
            rsa.encrypt(payload.encode('ascii'), pub_key)
        ).decode()

    @staticmethod
    def validate_jwt_token(token: str) -> Dict[str, str]:
        """
        Validate a JWT token.

        :param token: Token to validate.
        :return: On success, it returns the JWT payload with the following structure:

            .. code-block:: json

                {
                    "username": "user ID/name",
                    "created_at": "token creation timestamp",
                    "expires_at": "token expiration timestamp"
                }

        :raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
        """
        _, priv_key = get_or_generate_jwt_rsa_key_pair()

        try:
            payload = json.loads(
                rsa.decrypt(
                    base64.b64decode(token.encode('ascii')),
                    priv_key
                ).decode('ascii')
            )
        except (TypeError, ValueError, rsa.DecryptionError) as e:
            # rsa.DecryptionError is not a ValueError subclass: without it a
            # tampered token would escape as an unexpected exception type.
            raise InvalidJWTTokenException(f'Could not decode JWT token: {e}') from e

        if not isinstance(payload, dict):
            raise InvalidJWTTokenException('Could not decode JWT token: invalid payload')

        expires_at = payload.get('expires_at')
        if expires_at and time.time() > expires_at:
            raise InvalidJWTTokenException('Expired JWT token')

        return payload

    def _get_db_session(self):
        """
        Ensure the user tables exist and return a new ORM session bound to the
        db plugin's engine. Objects are not expired on commit, so records stay
        readable after the session commits.
        """
        Base.metadata.create_all(self._engine)
        session = scoped_session(sessionmaker(expire_on_commit=False))
        session.configure(bind=self._engine)
        return session()

    def _authenticate_user(self, session, username, password):
        """
        :return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
        """
        user = self._get_user(session, username)
        # A missing password hash can never match (and would make bcrypt raise)
        if not user or not user.password:
            return None

        if not self._check_password(password, user.password):
            return None

        return user


class User(Base):
    """ Models the User table """

    __tablename__ = 'user'
    __table_args__ = {'sqlite_autoincrement': True}

    user_id = Column(Integer, primary_key=True)
    username = Column(String, unique=True, nullable=False)
    # bcrypt hash of the password
    password = Column(String)
    # naive UTC datetime
    created_at = Column(DateTime)


class UserSession(Base):
    """ Models the UserSession table """

    __tablename__ = 'user_session'
    __table_args__ = {'sqlite_autoincrement': True}

    session_id = Column(Integer, primary_key=True)
    session_token = Column(String, unique=True, nullable=False)
    csrf_token = Column(String, unique=True)
    user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
    # naive UTC datetimes; expires_at is NULL for non-expiring sessions
    created_at = Column(DateTime)
    expires_at = Column(DateTime)


# vim:sw=4:ts=4:et: