import datetime
# NOTE(review): `hashlib` and `random` are no longer used in this module; the
# imports are kept in case other modules rely on them being importable from
# here - TODO confirm and remove.
import hashlib  # noqa: F401
import random  # noqa: F401
import secrets
import time
from typing import Optional, Dict

import bcrypt

# Fall back to the top-level ``jwt`` package when ``PyJWTError`` cannot be
# imported from ``jwt.exceptions``.
try:
    from jwt.exceptions import PyJWTError
    from jwt import encode as jwt_encode, decode as jwt_decode
except ImportError:
    from jwt import PyJWTError, encode as jwt_encode, decode as jwt_decode

from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import make_transient
from sqlalchemy.ext.declarative import declarative_base

from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.utils import get_or_generate_jwt_rsa_key_pair

Base = declarative_base()


class UserManager:
    """
    Main class for managing platform users.

    It wraps the ``db`` plugin and provides user creation/deletion, password
    checks, session tokens and JWT tokens on top of the :class:`User` and
    :class:`UserSession` tables.
    """

    def __init__(self):
        """
        Retrieve the ``db`` plugin and create the tables declared on ``Base``.
        """
        self.db = get_plugin('db')
        assert self.db, 'The db plugin is required by the user manager'
        self._engine = self.db.get_engine()
        self.db.create_all(self._engine, Base)

    @staticmethod
    def _mask_password(user):
        """
        Detach ``user`` from its session and blank its password hash.

        :param user: :class:`User` instance.
        :return: The same instance, with ``password`` set to ``None``.
        """
        # Make the object transient first, so that clearing the password is
        # never written back to the database.
        make_transient(user)
        user.password = None
        return user

    def get_user(self, username):
        """
        :param username: User name.
        :return: The matching :class:`User` with the password masked, or
            ``None`` if no such user exists.
        """
        with self.db.get_session() as session:
            user = self._get_user(session, username)
            if not user:
                return None

            session.expunge(user)
            return self._mask_password(user)

    def get_user_count(self):
        """
        :return: The number of rows in the user table.
        """
        with self.db.get_session() as session:
            return session.query(User).count()

    def get_users(self):
        """
        :return: The (unevaluated) query over all the :class:`User` rows.
        """
        # NOTE(review): the query object is returned as-is, so it is evaluated
        # by the caller after this context manager has exited, and passwords
        # are not masked. Kept unchanged because callers may depend on the
        # query interface (e.g. ``.all()``) - confirm before changing.
        with self.db.get_session() as session:
            return session.query(User)

    def create_user(self, username, password, **kwargs):
        """
        Create a new user.

        :param username: User name.
        :param password: Clear-text password; only its bcrypt hash is stored.
        :param kwargs: Extra attributes passed to the :class:`User`
            constructor.
        :return: The created :class:`User`, with the password masked.
        :raises ValueError: If ``username`` or ``password`` are empty.
        :raises NameError: If a user with the same name already exists.
        """
        if not username:
            raise ValueError('Invalid or empty username')
        if not password:
            raise ValueError('Please provide a password for the user')

        with self.db.get_session() as session:
            user = self._get_user(session, username)
            if user:
                raise NameError(f'The user {username} already exists')

            record = User(
                username=username,
                password=self._encrypt_password(password),
                created_at=datetime.datetime.utcnow(),
                **kwargs,
            )

            session.add(record)
            session.commit()
            # Read the record back so that the returned object reflects the
            # stored row.
            user = self._get_user(session, username)

        return self._mask_password(user)

    def update_password(self, username, old_password, new_password):
        """
        Change the password of a user.

        :param username: User name.
        :param old_password: Current clear-text password.
        :param new_password: New clear-text password.
        :return: ``True`` if the password was changed, ``False`` if the
            user/old password pair is not valid.
        """
        with self.db.get_session() as session:
            user = self._authenticate_user(session, username, old_password)
            if not user:
                return False

            user.password = self._encrypt_password(new_password)
            session.commit()
            return True

    def authenticate_user(self, username, password):
        """
        :param username: User name.
        :param password: Clear-text password.
        :return: The :class:`User` if the credentials are valid, ``None``
            otherwise.
        """
        with self.db.get_session() as session:
            return self._authenticate_user(session, username, password)

    def authenticate_user_session(self, session_token):
        """
        Resolve a session token to its user and session.

        :param session_token: Session token.
        :return: A ``(user, user_session)`` tuple, where ``user`` has its
            password masked, or ``(None, None)`` if the session does not
            exist, has expired or has no matching user.
        """
        with self.db.get_session() as session:
            user_session = (
                session.query(UserSession)
                .filter_by(session_token=session_token)
                .first()
            )

            if not user_session or (
                user_session.expires_at
                and user_session.expires_at < datetime.datetime.utcnow()
            ):
                return None, None

            user = session.query(User).filter_by(user_id=user_session.user_id).first()
            if not user:
                # A session row without its user can't be authenticated.
                return None, None

            return self._mask_password(user), user_session

    def delete_user(self, username):
        """
        Delete a user and all of its sessions.

        :param username: User name.
        :return: ``True`` on success.
        :raises NameError: If the user does not exist.
        """
        with self.db.get_session() as session:
            user = self._get_user(session, username)
            if not user:
                raise NameError(f'No such user: {username}')

            # Remove the sessions first, as they reference the user row.
            user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
            for user_session in user_sessions:
                session.delete(user_session)

            session.delete(user)
            session.commit()
            return True

    def delete_user_session(self, session_token):
        """
        :param session_token: Token of the session to delete.
        :return: ``True`` if the session was deleted, ``False`` if no session
            matches the token.
        """
        with self.db.get_session() as session:
            user_session = (
                session.query(UserSession)
                .filter_by(session_token=session_token)
                .first()
            )

            if not user_session:
                return False

            session.delete(user_session)
            session.commit()
            return True

    @staticmethod
    def _normalize_expiry(expires_at):
        """
        Convert an expiration value to a naive UTC datetime.

        Stored expiration times are compared against
        ``datetime.datetime.utcnow()``, so they have to be naive UTC values.

        :param expires_at: UNIX timestamp (int/float), ISO-8601 string or
            datetime. Naive datetimes/strings are taken as already being UTC.
            Falsy values are returned unchanged.
        :return: The normalized value.
        """
        if not expires_at:
            return expires_at

        if isinstance(expires_at, (int, float)):
            # fromtimestamp() without tz would yield the *local* time.
            return datetime.datetime.fromtimestamp(
                expires_at, tz=datetime.timezone.utc
            ).replace(tzinfo=None)

        if isinstance(expires_at, str):
            expires_at = datetime.datetime.fromisoformat(expires_at)

        if isinstance(expires_at, datetime.datetime) and expires_at.tzinfo is not None:
            expires_at = expires_at.astimezone(datetime.timezone.utc).replace(tzinfo=None)

        return expires_at

    def create_user_session(self, username, password, expires_at=None):
        """
        Create a new session for a user.

        :param username: User name.
        :param password: Clear-text password.
        :param expires_at: Optional expiration of the session, as a UNIX
            timestamp, an ISO-8601 string or a datetime.
        :return: The created :class:`UserSession`, or ``None`` if the
            credentials are not valid.
        """
        with self.db.get_session() as session:
            user = self._authenticate_user(session, username, password)
            if not user:
                return None

            user_session = UserSession(
                user_id=user.user_id,
                session_token=self.generate_session_token(),
                csrf_token=self.generate_session_token(),
                created_at=datetime.datetime.utcnow(),
                expires_at=self._normalize_expiry(expires_at),
            )

            session.add(user_session)
            session.commit()
            return user_session

    @staticmethod
    def _get_user(session, username):
        """
        :return: The first :class:`User` with the given name in ``session``,
            or ``None``.
        """
        return session.query(User).filter_by(username=username).first()

    @staticmethod
    def _encrypt_password(pwd):
        """
        :param pwd: Clear-text password (``str`` or ``bytes``).
        :return: The bcrypt hash of the password, as a string.
        """
        if isinstance(pwd, str):
            pwd = pwd.encode()
        return bcrypt.hashpw(pwd, bcrypt.gensalt(12)).decode()

    @classmethod
    def _check_password(cls, pwd, hashed_pwd):
        """
        :param pwd: Clear-text password (``str`` or ``bytes``).
        :param hashed_pwd: Stored bcrypt hash (``str`` or ``bytes``).
        :return: ``True`` if the password matches the hash.
        """
        # The password column is nullable: a missing hash (or a missing
        # password) never matches, instead of failing inside bcrypt.
        if pwd is None or not hashed_pwd:
            return False

        return bcrypt.checkpw(cls._to_bytes(pwd), cls._to_bytes(hashed_pwd))

    @staticmethod
    def _to_bytes(data) -> bytes:
        """
        :return: ``data`` encoded to bytes if it is a string, ``data`` itself
            otherwise.
        """
        if isinstance(data, str):
            data = data.encode()
        return data

    @staticmethod
    def generate_session_token():
        """
        :return: A random token, as a string of 64 hexadecimal characters.
        """
        # Session/CSRF tokens are secrets: they must come from a
        # cryptographically secure source, not from the ``random`` module.
        return secrets.token_hex(32)

    def get_user_by_session(self, session_token: str):
        """
        Get a user associated to a session token.

        :param session_token: Session token.
        :return: The matching :class:`User`, or ``None``. The session
            expiration is not checked here and the password is not masked.
        """
        with self.db.get_session() as session:
            return (
                session.query(User)
                .join(UserSession)
                .filter_by(session_token=session_token)
                .first()
            )

    def generate_jwt_token(self, username: str, password: str,
                           expires_at: Optional[datetime.datetime] = None) -> str:
        """
        Create a user JWT token for API usage.

        :param username: User name.
        :param password: Password.
        :param expires_at: Expiration datetime of the token.
        :return: The generated JWT token as a string.
        :raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
        """
        user = self.authenticate_user(username, password)
        if not user:
            raise InvalidCredentialsException()

        _, priv_key = get_or_generate_jwt_rsa_key_pair()
        payload = {
            'username': username,
            'created_at': datetime.datetime.now().timestamp(),
            'expires_at': expires_at.timestamp() if expires_at else None,
        }

        token = jwt_encode(payload, priv_key, algorithm='RS256')
        # The encoder may return bytes: always hand back a string.
        if isinstance(token, bytes):
            token = token.decode()

        return token

    @staticmethod
    def validate_jwt_token(token: str) -> Dict[str, str]:
        """
        Validate a JWT token.

        :param token: Token to validate.
        :return: On success, it returns the JWT payload with the following structure:

            .. code-block:: json

                {
                    "username": "user ID/name",
                    "created_at": "token creation timestamp",
                    "expires_at": "token expiration timestamp"
                }

        :raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
        """
        pub_key, _ = get_or_generate_jwt_rsa_key_pair()

        try:
            payload = jwt_decode(token.encode(), pub_key, algorithms=['RS256'])
        except PyJWTError as e:
            raise InvalidJWTTokenException(str(e)) from e

        # The expiration is stored in the payload as a UNIX timestamp.
        expires_at = payload.get('expires_at')
        if expires_at and time.time() > expires_at:
            raise InvalidJWTTokenException('Expired JWT token')

        return payload

    def _authenticate_user(self, session, username, password):
        """
        :return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None``
            otherwise.
        """
        user = self._get_user(session, username)
        if not user:
            return None

        if not self._check_password(password, user.password):
            return None

        return user


class User(Base):
    """ Models the User table """

    __tablename__ = 'user'
    __table_args__ = ({'sqlite_autoincrement': True})

    user_id = Column(Integer, primary_key=True)
    username = Column(String, unique=True, nullable=False)
    # bcrypt hash of the password
    password = Column(String)
    created_at = Column(DateTime)


class UserSession(Base):
    """ Models the UserSession table """

    __tablename__ = 'user_session'
    __table_args__ = ({'sqlite_autoincrement': True})

    session_id = Column(Integer, primary_key=True)
    session_token = Column(String, unique=True, nullable=False)
    csrf_token = Column(String, unique=True)
    user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
    created_at = Column(DateTime)
    # Compared against datetime.utcnow() when a session is authenticated
    expires_at = Column(DateTime)


# vim:sw=4:ts=4:et: