diff --git a/platypush/backend/http/__init__.py b/platypush/backend/http/__init__.py index 782bec3e80..e298b204ed 100644 --- a/platypush/backend/http/__init__.py +++ b/platypush/backend/http/__init__.py @@ -12,8 +12,8 @@ from tornado.ioloop import IOLoop from platypush.backend import Backend from platypush.backend.http.app import application -from platypush.backend.http.ws import scan_routes -from platypush.backend.http.ws.events import events_redis_topic +from platypush.backend.http.app.utils import get_ws_routes +from platypush.backend.http.app.ws.events import events_redis_topic from platypush.bus.redis import RedisBus from platypush.config import Config @@ -264,7 +264,7 @@ class HttpBackend(Backend): container = WSGIContainer(application) server = Application( [ - *[(route.path(), route) for route in scan_routes()], + *[(route.path(), route) for route in get_ws_routes()], (r'.*', FallbackHandler, {'fallback': container}), ] ) diff --git a/platypush/backend/http/app/utils.py b/platypush/backend/http/app/utils.py deleted file mode 100644 index 537775045c..0000000000 --- a/platypush/backend/http/app/utils.py +++ /dev/null @@ -1,324 +0,0 @@ -import importlib -import logging -import os - -from functools import wraps -from flask import abort, request, redirect, jsonify, current_app -from flask.wrappers import Response -from redis import Redis - -# NOTE: The HTTP service will *only* work on top of a Redis bus. The default -# internal bus service won't work as the web server will run in a different process. -from platypush.bus.redis import RedisBus - -from platypush.config import Config -from platypush.message import Message -from platypush.message.request import Request -from platypush.user import UserManager -from platypush.utils import get_redis_queue_name_by_message, get_ip_or_hostname - -_bus = None -_logger = None -user_manager = UserManager() - - -def bus(): - global _bus - if _bus is None: - _bus = RedisBus(redis_queue=current_app.config.get('redis_queue')) - return _bus - - -def logger(): - global _logger - if not _logger: - log_args = { - 'level': logging.INFO, - 'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s', - } - - level = (Config.get('backend.http') or {}).get('logging') or ( - Config.get('logging') or {} - ).get('level') - filename = (Config.get('backend.http') or {}).get('filename') - - if level: - log_args['level'] = ( - getattr(logging, level.upper()) if isinstance(level, str) else level - ) - if filename: - log_args['filename'] = filename - - logging.basicConfig(**log_args) - _logger = logging.getLogger('platypush:web') - - return _logger - - -def get_message_response(msg): - redis = Redis(**bus().redis_args) - redis_queue = get_redis_queue_name_by_message(msg) - if not redis_queue: - return - - response = redis.blpop(redis_queue, timeout=60) - if response and len(response) > 1: - response = Message.build(response[1]) - else: - response = None - - return response - - -# noinspection PyProtectedMember -def get_http_port(): - from platypush.backend.http import HttpBackend - - http_conf = Config.get('backend.http') or {} - return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT) - - -def send_message(msg, wait_for_response=True): - msg = Message.build(msg) - if msg is None: - return - - if isinstance(msg, Request): - msg.origin = 'http' - - if Config.get('token'): - msg.token = Config.get('token') - - bus().post(msg) - - if isinstance(msg, Request) and wait_for_response: - response = get_message_response(msg) - logger().debug('Processing response on the HTTP backend: {}'.format(response)) - - return response - - -def send_request(action, wait_for_response=True, **kwargs): - msg = {'type': 'request', 'action': action} - - if kwargs: - msg['args'] = kwargs - - return send_message(msg, wait_for_response=wait_for_response) - - -def _authenticate_token(): - token = Config.get('token') - user_token = None - - if 'X-Token' in request.headers: - user_token = request.headers['X-Token'] - elif 'Authorization' in request.headers and request.headers[ - 'Authorization' - ].startswith('Bearer '): - user_token = request.headers['Authorization'][7:] - elif 'token' in request.args: - user_token = request.args.get('token') - if not user_token: - return False - - try: - user_manager.validate_jwt_token(user_token) - return True - except Exception as e: - logger().debug(str(e)) - return token and user_token == token - - -def _authenticate_http(): - if not request.authorization: - return False - - username = request.authorization.username - password = request.authorization.password - return user_manager.authenticate_user(username, password) - - -def _authenticate_session(): - user_session_token = None - user = None - - if 'X-Session-Token' in request.headers: - user_session_token = request.headers['X-Session-Token'] - elif 'session_token' in request.args: - user_session_token = request.args.get('session_token') - elif 'session_token' in request.cookies: - user_session_token = request.cookies.get('session_token') - - if user_session_token: - user, _ = user_manager.authenticate_user_session(user_session_token) - - return user is not None - - -def _authenticate_csrf_token(): - user_session_token = None - - if 'X-Session-Token' in request.headers: - user_session_token = request.headers['X-Session-Token'] - elif 'session_token' in request.args: - user_session_token = request.args.get('session_token') - elif 'session_token' in request.cookies: - user_session_token = request.cookies.get('session_token') - - if user_session_token: - user, session = user_manager.authenticate_user_session(user_session_token) - else: - return False - - if user is None or session is None: - return False - - return ( - session.csrf_token is None - or request.form.get('csrf_token') == session.csrf_token - ) - - -def authenticate( - redirect_page='', - skip_auth_methods=None, - check_csrf_token=False, - json=False, -): - def on_auth_fail(has_users=True): - if json: - if has_users: - return ( - jsonify( - { - 'message': 'Not logged in', - } - ), - 401, - ) - - return ( - jsonify( - { - 'message': 'Please register a user through ' - 'the web panel first', - } - ), - 412, - ) - - target_page = 'login' if has_users else 'register' - return redirect(f'/{target_page}?redirect={redirect_page or request.url}', 307) - - def decorator(f): - @wraps(f) - def wrapper(*args, **kwargs): - n_users = user_manager.get_user_count() - skip_methods = skip_auth_methods or [] - - # User/pass HTTP authentication - http_auth_ok = True - if n_users > 0 and 'http' not in skip_methods: - http_auth_ok = _authenticate_http() - if http_auth_ok: - return f(*args, **kwargs) - - # Token-based authentication - token_auth_ok = True - if 'token' not in skip_methods: - token_auth_ok = _authenticate_token() - if token_auth_ok: - return f(*args, **kwargs) - - # Session token based authentication - session_auth_ok = True - if n_users > 0 and 'session' not in skip_methods: - session_auth_ok = _authenticate_session() - if session_auth_ok: - return f(*args, **kwargs) - - return on_auth_fail() - - # CSRF token check - if check_csrf_token: - csrf_check_ok = _authenticate_csrf_token() - if not csrf_check_ok: - return abort(403, 'Invalid or missing csrf_token') - - if n_users == 0 and 'session' not in skip_methods: - return on_auth_fail(has_users=False) - - if ( - ('http' not in skip_methods and http_auth_ok) - or ('token' not in skip_methods and token_auth_ok) - or ('session' not in skip_methods and session_auth_ok) - ): - return f(*args, **kwargs) - - return Response( - 'Authentication required', - 401, - {'WWW-Authenticate': 'Basic realm="Login required"'}, - ) - - return wrapper - - return decorator - - -def get_routes(): - routes_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'routes') - routes = [] - base_module = '.'.join(__name__.split('.')[:-1]) - - for path, _, files in os.walk(routes_dir): - for f in files: - if f.endswith('.py'): - mod_name = '.'.join( - ( - base_module - + '.' - + os.path.join(path, f) - .replace(os.path.dirname(__file__), '')[1:] - .replace(os.sep, '.') - ).split('.')[: (-2 if f == '__init__.py' else -1)] - ) - - try: - mod = importlib.import_module(mod_name) - if hasattr(mod, '__routes__'): - routes.extend(mod.__routes__) - except Exception as e: - logger().warning( - 'Could not import routes from {}/{}: {}: {}'.format( - path, f, type(e), str(e) - ) - ) - - return routes - - -def get_local_base_url(): - http_conf = Config.get('backend.http') or {} - bind_address = http_conf.get('bind_address') - if not bind_address or bind_address == '0.0.0.0': - bind_address = 'localhost' - - return '{proto}://{host}:{port}'.format( - proto=('https' if http_conf.get('ssl_cert') else 'http'), - host=bind_address, - port=get_http_port(), - ) - - -def get_remote_base_url(): - http_conf = Config.get('backend.http') or {} - return '{proto}://{host}:{port}'.format( - proto=('https' if http_conf.get('ssl_cert') else 'http'), - host=get_ip_or_hostname(), - port=get_http_port(), - ) - - -# vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/utils/__init__.py b/platypush/backend/http/app/utils/__init__.py new file mode 100644 index 0000000000..08eae34465 --- /dev/null +++ b/platypush/backend/http/app/utils/__init__.py @@ -0,0 +1,37 @@ +from .auth import ( + authenticate, + authenticate_token, + authenticate_user_pass, + get_auth_status, +) +from .bus import bus, get_message_response, send_message, send_request +from .logger import logger +from .routes import ( + get_http_port, + get_ip_or_hostname, + get_local_base_url, + get_remote_base_url, + get_routes, +) +from .ws import get_ws_routes + +__all__ = [ + 'authenticate', + 'authenticate_token', + 'authenticate_user_pass', + 'bus', + 'get_auth_status', + 'get_http_port', + 'get_ip_or_hostname', + 'get_local_base_url', + 'get_message_response', + 'get_remote_base_url', + 'get_routes', + 'get_ws_routes', + 'logger', + 'send_message', + 'send_request', +] + + +# vim:sw=4:ts=4:et: diff --git a/platypush/backend/http/app/utils/auth/__init__.py b/platypush/backend/http/app/utils/auth/__init__.py new file mode 100644 index 0000000000..7d1b5ef65c --- /dev/null +++ b/platypush/backend/http/app/utils/auth/__init__.py @@ -0,0 +1,196 @@ +import base64 +from functools import wraps +from typing import Optional + +from flask import request, redirect, jsonify +from flask.wrappers import Response + +from platypush.config import Config +from platypush.user import UserManager + +from ..logger import logger +from .status import AuthStatus + +user_manager = UserManager() + + +def get_arg(req, name: str) -> Optional[str]: + # The Flask way + if hasattr(req, 'args'): + return req.args.get(name) + + # The Tornado way + if hasattr(req, 'arguments'): + arg = req.arguments.get(name) + if arg: + return arg[0].decode() + + return None + + +def get_cookie(req, name: str) -> Optional[str]: + cookie = req.cookies.get(name) + if not cookie: + return None + + # The Flask way + if isinstance(cookie, str): + return cookie + + # The Tornado way + return cookie.value + + +def authenticate_token(req): + token = Config.get('token') + user_token = None + + if 'X-Token' in req.headers: + user_token = req.headers['X-Token'] + elif 'Authorization' in req.headers and req.headers['Authorization'].startswith( + 'Bearer ' + ): + user_token = req.headers['Authorization'][7:] + else: + user_token = get_arg(req, 'token') + + if not user_token: + return False + + try: + user_manager.validate_jwt_token(user_token) + return True + except Exception as e: + logger().debug(str(e)) + return bool(token and user_token == token) + + +def authenticate_user_pass(req): + # Flask populates request.authorization + if hasattr(req, 'authorization'): + if not req.authorization: + return False + + username = req.authorization.username + password = req.authorization.password + + # Otherwise, check the Authorization header + elif 'Authorization' in req.headers and req.headers['Authorization'].startswith( + 'Basic ' + ): + auth = req.headers['Authorization'][6:] + try: + auth = base64.b64decode(auth) + except ValueError: + pass + + username, password = auth.decode().split(':', maxsplit=1) + else: + return False + + return user_manager.authenticate_user(username, password) + + +def authenticate_session(req): + user = None + + # Check the X-Session-Token header + user_session_token = req.headers.get('X-Session-Token') + + # Check the `session_token` query/body parameter + if not user_session_token: + user_session_token = get_arg(req, 'session_token') + + # Check the `session_token` cookie + if not user_session_token: + user_session_token = get_cookie(req, 'session_token') + + if user_session_token: + user, _ = user_manager.authenticate_user_session(user_session_token) + + return user is not None + + +def authenticate( + redirect_page='', + skip_auth_methods=None, + json=False, +): + """ + Authentication decorator for Flask routes. + """ + + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + auth_status = get_auth_status( + request, + skip_auth_methods=skip_auth_methods, + ) + + if auth_status == AuthStatus.OK: + return f(*args, **kwargs) + + if json: + return jsonify(auth_status.to_dict()), auth_status.value.code + + if auth_status == AuthStatus.NO_USERS: + return redirect( + f'/register?redirect={redirect_page or request.url}', 307 + ) + + if auth_status == AuthStatus.UNAUTHORIZED: + return redirect(f'/login?redirect={redirect_page or request.url}', 307) + + return Response( + 'Authentication required', + 401, + {'WWW-Authenticate': 'Basic realm="Login required"'}, + ) + + return wrapper + + return decorator + + +# pylint: disable=too-many-return-statements +def get_auth_status(req, skip_auth_methods=None) -> AuthStatus: + """ + Check against the available authentication methods (except those listed in + ``skip_auth_methods``) if the user is properly authenticated. + """ + + n_users = user_manager.get_user_count() + skip_methods = skip_auth_methods or [] + + # User/pass HTTP authentication + http_auth_ok = True + if n_users > 0 and 'http' not in skip_methods: + http_auth_ok = authenticate_user_pass(req) + if http_auth_ok: + return AuthStatus.OK + + # Token-based authentication + token_auth_ok = True + if 'token' not in skip_methods: + token_auth_ok = authenticate_token(req) + if token_auth_ok: + return AuthStatus.OK + + # Session token based authentication + session_auth_ok = True + if n_users > 0 and 'session' not in skip_methods: + return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED + + # At least a user should be created before accessing an authenticated resource + if n_users == 0 and 'session' not in skip_methods: + return AuthStatus.NO_USERS + + if ( # pylint: disable=too-many-boolean-expressions + ('http' not in skip_methods and http_auth_ok) + or ('token' not in skip_methods and token_auth_ok) + or ('session' not in skip_methods and session_auth_ok) + ): + return AuthStatus.OK + + return AuthStatus.UNAUTHORIZED diff --git a/platypush/backend/http/app/utils/auth/status.py b/platypush/backend/http/app/utils/auth/status.py new file mode 100644 index 0000000000..ee64fb51d9 --- /dev/null +++ b/platypush/backend/http/app/utils/auth/status.py @@ -0,0 +1,21 @@ +from collections import namedtuple +from enum import Enum + + +StatusValue = namedtuple('StatusValue', ['code', 'message']) + + +class AuthStatus(Enum): + """ + Models the status of the authentication. + """ + + OK = StatusValue(200, 'OK') + UNAUTHORIZED = StatusValue(401, 'Unauthorized') + NO_USERS = StatusValue(412, 'Please create a user first') + + def to_dict(self): + return { + 'code': self.value[0], + 'message': self.value[1], + } diff --git a/platypush/backend/http/app/utils/bus.py b/platypush/backend/http/app/utils/bus.py new file mode 100644 index 0000000000..5834a46d80 --- /dev/null +++ b/platypush/backend/http/app/utils/bus.py @@ -0,0 +1,63 @@ +from flask import current_app +from redis import Redis + +from platypush.bus.redis import RedisBus +from platypush.config import Config +from platypush.message import Message +from platypush.message.request import Request +from platypush.utils import get_redis_queue_name_by_message + +from .logger import logger + +_bus = None + + +def bus(): + global _bus + if _bus is None: + _bus = RedisBus(redis_queue=current_app.config.get('redis_queue')) + return _bus + + +def send_message(msg, wait_for_response=True): + msg = Message.build(msg) + if msg is None: + return + + if isinstance(msg, Request): + msg.origin = 'http' + + if Config.get('token'): + msg.token = Config.get('token') + + bus().post(msg) + + if isinstance(msg, Request) and wait_for_response: + response = get_message_response(msg) + logger().debug('Processing response on the HTTP backend: {}'.format(response)) + + return response + + +def send_request(action, wait_for_response=True, **kwargs): + msg = {'type': 'request', 'action': action} + + if kwargs: + msg['args'] = kwargs + + return send_message(msg, wait_for_response=wait_for_response) + + +def get_message_response(msg): + redis = Redis(**bus().redis_args) + redis_queue = get_redis_queue_name_by_message(msg) + if not redis_queue: + return + + response = redis.blpop(redis_queue, timeout=60) + if response and len(response) > 1: + response = Message.build(response[1]) + else: + response = None + + return response diff --git a/platypush/backend/http/app/utils/logger.py b/platypush/backend/http/app/utils/logger.py new file mode 100644 index 0000000000..b731b5578e --- /dev/null +++ b/platypush/backend/http/app/utils/logger.py @@ -0,0 +1,31 @@ +import logging + +from platypush.config import Config + +_logger = None + + +def logger(): + global _logger + if not _logger: + log_args = { + 'level': logging.INFO, + 'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s', + } + + level = (Config.get('backend.http') or {}).get('logging') or ( + Config.get('logging') or {} + ).get('level') + filename = (Config.get('backend.http') or {}).get('filename') + + if level: + log_args['level'] = ( + getattr(logging, level.upper()) if isinstance(level, str) else level + ) + if filename: + log_args['filename'] = filename + + logging.basicConfig(**log_args) + _logger = logging.getLogger('platypush:web') + + return _logger diff --git a/platypush/backend/http/app/utils/routes.py b/platypush/backend/http/app/utils/routes.py new file mode 100644 index 0000000000..b27fa4ca10 --- /dev/null +++ b/platypush/backend/http/app/utils/routes.py @@ -0,0 +1,59 @@ +import importlib +import inspect +import os +import pkgutil + +from platypush.backend import Backend +from platypush.config import Config +from platypush.utils import get_ip_or_hostname + +from .logger import logger + + +def get_http_port(): + from platypush.backend.http import HttpBackend + + http_conf = Config.get('backend.http') or {} + return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT) + + +def get_routes(): + base_pkg = '.'.join([Backend.__module__, 'http', 'app', 'routes']) + base_dir = os.path.join( + os.path.dirname(inspect.getfile(Backend)), 'http', 'app', 'routes' + ) + routes = [] + + for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'): + try: + module = importlib.import_module(mod_name) + if hasattr(module, '__routes__'): + routes.extend(module.__routes__) + except Exception as e: + logger.warning('Could not import module %s', mod_name) + logger.exception(e) + continue + + return routes + + +def get_local_base_url(): + http_conf = Config.get('backend.http') or {} + bind_address = http_conf.get('bind_address') + if not bind_address or bind_address == '0.0.0.0': + bind_address = 'localhost' + + return '{proto}://{host}:{port}'.format( + proto=('https' if http_conf.get('ssl_cert') else 'http'), + host=bind_address, + port=get_http_port(), + ) + + +def get_remote_base_url(): + http_conf = Config.get('backend.http') or {} + return '{proto}://{host}:{port}'.format( + proto=('https' if http_conf.get('ssl_cert') else 'http'), + host=get_ip_or_hostname(), + port=get_http_port(), + ) diff --git a/platypush/backend/http/ws/_scanner.py b/platypush/backend/http/app/utils/ws.py similarity index 70% rename from platypush/backend/http/ws/_scanner.py rename to platypush/backend/http/app/utils/ws.py index a42131bc8b..2b43b30881 100644 --- a/platypush/backend/http/ws/_scanner.py +++ b/platypush/backend/http/app/utils/ws.py @@ -5,18 +5,20 @@ from typing import List, Type import pkgutil -from ._base import WSRoute, logger +from ..ws import WSRoute, logger -def scan_routes() -> List[Type[WSRoute]]: +def get_ws_routes() -> List[Type[WSRoute]]: """ Scans for websocket route objects. """ + from platypush.backend.http import HttpBackend - base_dir = os.path.dirname(__file__) + base_pkg = '.'.join([HttpBackend.__module__, 'app', 'ws']) + base_dir = os.path.join(os.path.dirname(inspect.getfile(HttpBackend)), 'app', 'ws') routes = [] - for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=__package__ + '.'): + for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'): try: module = importlib.import_module(mod_name) except Exception as e: diff --git a/platypush/backend/http/app/ws/__init__.py b/platypush/backend/http/app/ws/__init__.py new file mode 100644 index 0000000000..7d62e1fe05 --- /dev/null +++ b/platypush/backend/http/app/ws/__init__.py @@ -0,0 +1,3 @@ +from ._base import WSRoute, logger, pubsub_redis_topic + +__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic'] diff --git a/platypush/backend/http/ws/_base.py b/platypush/backend/http/app/ws/_base.py similarity index 78% rename from platypush/backend/http/ws/_base.py rename to platypush/backend/http/app/ws/_base.py index 2e7eab2b41..d1596940e1 100644 --- a/platypush/backend/http/ws/_base.py +++ b/platypush/backend/http/app/ws/_base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractclassmethod +import json from logging import getLogger from threading import RLock, Thread from typing import Any, Generator, Iterable, Optional, Union @@ -8,7 +9,9 @@ from redis import ConnectionError as RedisConnectionError from tornado.ioloop import IOLoop from tornado.websocket import WebSocketHandler +from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status from platypush.config import Config +from platypush.message import Message from platypush.utils import get_redis logger = getLogger(__name__) @@ -32,7 +35,14 @@ class WSRoute(WebSocketHandler, Thread, ABC): @override def open(self, *_, **__): - logger.info('Client %s connected to %s', self.request.remote_ip, self.path()) + auth_status = get_auth_status(self.request) + if auth_status != AuthStatus.OK: + self.close(code=1008, reason=auth_status.value.message) # Policy Violation + return + + logger.info( + 'Client %s connected to %s', self.request.remote_ip, self.request.path + ) self.name = f'ws:{self.app_name()}@{self.request.remote_ip}' self.start() @@ -52,6 +62,10 @@ class WSRoute(WebSocketHandler, Thread, ABC): def path(cls) -> str: return f'/ws/{cls.app_name()}' + @property + def auth_required(self): + return True + def subscribe(self, *topics: str) -> None: with self._sub_lock: for topic in topics: @@ -78,7 +92,12 @@ class WSRoute(WebSocketHandler, Thread, ABC): except RedisConnectionError: return - def send(self, msg: Union[str, bytes]) -> None: + def send(self, msg: Union[str, bytes, dict, list, tuple, set]) -> None: + if isinstance(msg, (list, tuple, set)): + msg = list(msg) + if isinstance(msg, (list, dict)): + msg = json.dumps(msg, cls=Message.Encoder) + self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore self.write_message, msg ) @@ -99,7 +118,7 @@ class WSRoute(WebSocketHandler, Thread, ABC): logger.info( 'Client %s disconnected from %s, reason=%s, message=%s', self.request.remote_ip, - self.path(), + self.request.path, self.close_code, self.close_reason, ) diff --git a/platypush/backend/http/ws/events.py b/platypush/backend/http/app/ws/events.py similarity index 100% rename from platypush/backend/http/ws/events.py rename to platypush/backend/http/app/ws/events.py diff --git a/platypush/backend/http/ws/__init__.py b/platypush/backend/http/ws/__init__.py deleted file mode 100644 index 6b117c9c66..0000000000 --- a/platypush/backend/http/ws/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from ._base import WSRoute, logger, pubsub_redis_topic -from ._scanner import scan_routes - -__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic', 'scan_routes']