platypush/platypush/user/__init__.py

281 lines
9 KiB
Python

import datetime
import hashlib
import random
import time
from typing import Optional, Dict
import bcrypt
import jwt
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base
from platypush.context import get_plugin
from platypush.exceptions.user import InvalidJWTTokenException, InvalidCredentialsException
from platypush.utils import get_or_generate_jwt_rsa_key_pair
Base = declarative_base()
class UserManager:
"""
Main class for managing platform users
"""
# noinspection PyProtectedMember
def __init__(self):
db_plugin = get_plugin('db')
if not db_plugin:
raise ModuleNotFoundError('Please enable/configure the db plugin for multi-user support')
self._engine = db_plugin._get_engine()
def get_user(self, username):
session = self._get_db_session()
user = self._get_user(session, username)
if not user:
return None
# Hide password
user.password = None
return user
def get_user_count(self):
session = self._get_db_session()
return session.query(User).count()
def get_users(self):
session = self._get_db_session()
return session.query(User)
def create_user(self, username, password, **kwargs):
session = self._get_db_session()
if not username:
raise ValueError('Invalid or empty username')
if not password:
raise ValueError('Please provide a password for the user')
user = self._get_user(session, username)
if user:
raise NameError('The user {} already exists'.format(username))
record = User(username=username, password=self._encrypt_password(password),
created_at=datetime.datetime.utcnow(), **kwargs)
session.add(record)
session.commit()
user = self._get_user(session, username)
# Hide password
user.password = None
return user
def update_password(self, username, old_password, new_password):
session = self._get_db_session()
if not self._authenticate_user(session, username, old_password):
return False
user = self._get_user(session, username)
user.password = self._encrypt_password(new_password)
session.commit()
return True
def authenticate_user(self, username, password):
session = self._get_db_session()
return self._authenticate_user(session, username, password)
def authenticate_user_session(self, session_token):
session = self._get_db_session()
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session or (
user_session.expires_at and user_session.expires_at < datetime.datetime.utcnow()):
return None, None
user = session.query(User).filter_by(user_id=user_session.user_id).first()
# Hide password
user.password = None
return user, session
def delete_user(self, username):
session = self._get_db_session()
user = self._get_user(session, username)
if not user:
raise NameError('No such user: {}'.format(username))
user_sessions = session.query(UserSession).filter_by(user_id=user.user_id).all()
for user_session in user_sessions:
session.delete(user_session)
session.delete(user)
session.commit()
return True
def delete_user_session(self, session_token):
session = self._get_db_session()
user_session = session.query(UserSession).filter_by(session_token=session_token).first()
if not user_session:
return False
session.delete(user_session)
session.commit()
return True
def create_user_session(self, username, password, expires_at=None):
session = self._get_db_session()
user = self._authenticate_user(session, username, password)
if not user:
return None
if expires_at:
if isinstance(expires_at, int) or isinstance(expires_at, float):
expires_at = datetime.datetime.fromtimestamp(expires_at)
elif isinstance(expires_at, str):
expires_at = datetime.datetime.fromisoformat(expires_at)
user_session = UserSession(user_id=user.user_id, session_token=self.generate_session_token(),
csrf_token=self.generate_session_token(), created_at=datetime.datetime.utcnow(),
expires_at=expires_at)
session.add(user_session)
session.commit()
return user_session
@staticmethod
def _get_user(session, username):
return session.query(User).filter_by(username=username).first()
@staticmethod
def _encrypt_password(pwd):
return bcrypt.hashpw(pwd.encode(), bcrypt.gensalt(12))
@staticmethod
def _check_password(pwd, hashed_pwd):
return bcrypt.checkpw(pwd.encode(), hashed_pwd)
@staticmethod
def generate_session_token():
rand = bytes(random.randint(0, 255) for _ in range(0, 255))
return hashlib.sha256(rand).hexdigest()
def get_user_by_session(self, session_token: str):
"""
Get a user associated to a session token.
:param session_token: Session token.
"""
session = self._get_db_session()
return session.query(User).join(UserSession).filter_by(session_token=session_token).first()
if not user:
return None
return {
'user_id': user.user_id,
'username': user.username,
'created_at': user.created_at,
}
def generate_jwt_token(self, username: str, password: str, expires_at: Optional[datetime.datetime] = None) -> str:
"""
Create a user JWT token for API usage.
:param username: User name.
:param password: Password.
:param expires_at: Expiration datetime of the token.
:return: The generated JWT token as a string.
:raises: :class:`platypush.exceptions.user.InvalidCredentialsException` in case of invalid credentials.
"""
user = self.authenticate_user(username, password)
if not user:
raise InvalidCredentialsException()
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
payload = {
'username': username,
'created_at': datetime.datetime.now().timestamp(),
'expires_at': expires_at.timestamp() if expires_at else None,
}
return jwt.encode(payload, priv_key, algorithm='RS256').decode()
@staticmethod
def validate_jwt_token(token: str) -> Dict[str, str]:
"""
Validate a JWT token.
:param token: Token to validate.
:return: On success, it returns the JWT payload with the following structure:
.. code-block:: json
{
"username": "user ID/name",
"created_at": "token creation timestamp",
"expires_at": "token expiration timestamp"
}
:raises: :class:`platypush.exceptions.user.InvalidJWTTokenException` in case of invalid token.
"""
pub_key, priv_key = get_or_generate_jwt_rsa_key_pair()
try:
payload = jwt.decode(token.encode(), pub_key, algorithms=['RS256'])
except jwt.exceptions.PyJWTError as e:
raise InvalidJWTTokenException(str(e))
expires_at = payload.get('expires_at')
if expires_at and time.time() > expires_at:
raise InvalidJWTTokenException('Expired JWT token')
return payload
def _get_db_session(self):
Base.metadata.create_all(self._engine)
session = scoped_session(sessionmaker())
session.configure(bind=self._engine)
return session()
def _authenticate_user(self, session, username, password):
"""
:return: :class:`platypush.user.User` instance if the user exists and the password is valid, ``None`` otherwise.
"""
user = self._get_user(session, username)
if not user:
return None
if not self._check_password(password, user.password):
return None
return user
class User(Base):
""" Models the User table """
__tablename__ = 'user'
__table_args__ = ({'sqlite_autoincrement': True})
user_id = Column(Integer, primary_key=True)
username = Column(String, unique=True, nullable=False)
password = Column(String)
created_at = Column(DateTime)
class UserSession(Base):
""" Models the UserSession table """
__tablename__ = 'user_session'
__table_args__ = ({'sqlite_autoincrement': True})
session_id = Column(Integer, primary_key=True)
session_token = Column(String, unique=True, nullable=False)
csrf_token = Column(String, unique=True)
user_id = Column(Integer, ForeignKey('user.user_id'), nullable=False)
created_at = Column(DateTime)
expires_at = Column(DateTime)
# vim:sw=4:ts=4:et: