[#260] Implemented authentication for websocket routes.

Plus, refactored the `backend.http.app.utils` module by breaking it down
into multiple components, as the module was starting to get too large.
This commit is contained in:
Fabio Manganiello 2023-05-09 00:03:11 +02:00
parent 2d4b179879
commit edb7197f71
Signed by: blacklight
GPG key ID: D90FBA7F76362774
13 changed files with 441 additions and 338 deletions

View file

@ -12,8 +12,8 @@ from tornado.ioloop import IOLoop
from platypush.backend import Backend
from platypush.backend.http.app import application
from platypush.backend.http.ws import scan_routes
from platypush.backend.http.ws.events import events_redis_topic
from platypush.backend.http.app.utils import get_ws_routes
from platypush.backend.http.app.ws.events import events_redis_topic
from platypush.bus.redis import RedisBus
from platypush.config import Config
@ -264,7 +264,7 @@ class HttpBackend(Backend):
container = WSGIContainer(application)
server = Application(
[
*[(route.path(), route) for route in scan_routes()],
*[(route.path(), route) for route in get_ws_routes()],
(r'.*', FallbackHandler, {'fallback': container}),
]
)

View file

@ -1,324 +0,0 @@
import importlib
import logging
import os
from functools import wraps
from flask import abort, request, redirect, jsonify, current_app
from flask.wrappers import Response
from redis import Redis
# NOTE: The HTTP service will *only* work on top of a Redis bus. The default
# internal bus service won't work as the web server will run in a different process.
from platypush.bus.redis import RedisBus
from platypush.config import Config
from platypush.message import Message
from platypush.message.request import Request
from platypush.user import UserManager
from platypush.utils import get_redis_queue_name_by_message, get_ip_or_hostname
_bus = None
_logger = None
user_manager = UserManager()
def bus():
global _bus
if _bus is None:
_bus = RedisBus(redis_queue=current_app.config.get('redis_queue'))
return _bus
def logger():
global _logger
if not _logger:
log_args = {
'level': logging.INFO,
'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s',
}
level = (Config.get('backend.http') or {}).get('logging') or (
Config.get('logging') or {}
).get('level')
filename = (Config.get('backend.http') or {}).get('filename')
if level:
log_args['level'] = (
getattr(logging, level.upper()) if isinstance(level, str) else level
)
if filename:
log_args['filename'] = filename
logging.basicConfig(**log_args)
_logger = logging.getLogger('platypush:web')
return _logger
def get_message_response(msg):
redis = Redis(**bus().redis_args)
redis_queue = get_redis_queue_name_by_message(msg)
if not redis_queue:
return
response = redis.blpop(redis_queue, timeout=60)
if response and len(response) > 1:
response = Message.build(response[1])
else:
response = None
return response
# noinspection PyProtectedMember
def get_http_port():
from platypush.backend.http import HttpBackend
http_conf = Config.get('backend.http') or {}
return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT)
def send_message(msg, wait_for_response=True):
msg = Message.build(msg)
if msg is None:
return
if isinstance(msg, Request):
msg.origin = 'http'
if Config.get('token'):
msg.token = Config.get('token')
bus().post(msg)
if isinstance(msg, Request) and wait_for_response:
response = get_message_response(msg)
logger().debug('Processing response on the HTTP backend: {}'.format(response))
return response
def send_request(action, wait_for_response=True, **kwargs):
msg = {'type': 'request', 'action': action}
if kwargs:
msg['args'] = kwargs
return send_message(msg, wait_for_response=wait_for_response)
def _authenticate_token():
token = Config.get('token')
user_token = None
if 'X-Token' in request.headers:
user_token = request.headers['X-Token']
elif 'Authorization' in request.headers and request.headers[
'Authorization'
].startswith('Bearer '):
user_token = request.headers['Authorization'][7:]
elif 'token' in request.args:
user_token = request.args.get('token')
if not user_token:
return False
try:
user_manager.validate_jwt_token(user_token)
return True
except Exception as e:
logger().debug(str(e))
return token and user_token == token
def _authenticate_http():
if not request.authorization:
return False
username = request.authorization.username
password = request.authorization.password
return user_manager.authenticate_user(username, password)
def _authenticate_session():
user_session_token = None
user = None
if 'X-Session-Token' in request.headers:
user_session_token = request.headers['X-Session-Token']
elif 'session_token' in request.args:
user_session_token = request.args.get('session_token')
elif 'session_token' in request.cookies:
user_session_token = request.cookies.get('session_token')
if user_session_token:
user, _ = user_manager.authenticate_user_session(user_session_token)
return user is not None
def _authenticate_csrf_token():
user_session_token = None
if 'X-Session-Token' in request.headers:
user_session_token = request.headers['X-Session-Token']
elif 'session_token' in request.args:
user_session_token = request.args.get('session_token')
elif 'session_token' in request.cookies:
user_session_token = request.cookies.get('session_token')
if user_session_token:
user, session = user_manager.authenticate_user_session(user_session_token)
else:
return False
if user is None or session is None:
return False
return (
session.csrf_token is None
or request.form.get('csrf_token') == session.csrf_token
)
def authenticate(
redirect_page='',
skip_auth_methods=None,
check_csrf_token=False,
json=False,
):
def on_auth_fail(has_users=True):
if json:
if has_users:
return (
jsonify(
{
'message': 'Not logged in',
}
),
401,
)
return (
jsonify(
{
'message': 'Please register a user through '
'the web panel first',
}
),
412,
)
target_page = 'login' if has_users else 'register'
return redirect(f'/{target_page}?redirect={redirect_page or request.url}', 307)
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
n_users = user_manager.get_user_count()
skip_methods = skip_auth_methods or []
# User/pass HTTP authentication
http_auth_ok = True
if n_users > 0 and 'http' not in skip_methods:
http_auth_ok = _authenticate_http()
if http_auth_ok:
return f(*args, **kwargs)
# Token-based authentication
token_auth_ok = True
if 'token' not in skip_methods:
token_auth_ok = _authenticate_token()
if token_auth_ok:
return f(*args, **kwargs)
# Session token based authentication
session_auth_ok = True
if n_users > 0 and 'session' not in skip_methods:
session_auth_ok = _authenticate_session()
if session_auth_ok:
return f(*args, **kwargs)
return on_auth_fail()
# CSRF token check
if check_csrf_token:
csrf_check_ok = _authenticate_csrf_token()
if not csrf_check_ok:
return abort(403, 'Invalid or missing csrf_token')
if n_users == 0 and 'session' not in skip_methods:
return on_auth_fail(has_users=False)
if (
('http' not in skip_methods and http_auth_ok)
or ('token' not in skip_methods and token_auth_ok)
or ('session' not in skip_methods and session_auth_ok)
):
return f(*args, **kwargs)
return Response(
'Authentication required',
401,
{'WWW-Authenticate': 'Basic realm="Login required"'},
)
return wrapper
return decorator
def get_routes():
routes_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'routes')
routes = []
base_module = '.'.join(__name__.split('.')[:-1])
for path, _, files in os.walk(routes_dir):
for f in files:
if f.endswith('.py'):
mod_name = '.'.join(
(
base_module
+ '.'
+ os.path.join(path, f)
.replace(os.path.dirname(__file__), '')[1:]
.replace(os.sep, '.')
).split('.')[: (-2 if f == '__init__.py' else -1)]
)
try:
mod = importlib.import_module(mod_name)
if hasattr(mod, '__routes__'):
routes.extend(mod.__routes__)
except Exception as e:
logger().warning(
'Could not import routes from {}/{}: {}: {}'.format(
path, f, type(e), str(e)
)
)
return routes
def get_local_base_url():
http_conf = Config.get('backend.http') or {}
bind_address = http_conf.get('bind_address')
if not bind_address or bind_address == '0.0.0.0':
bind_address = 'localhost'
return '{proto}://{host}:{port}'.format(
proto=('https' if http_conf.get('ssl_cert') else 'http'),
host=bind_address,
port=get_http_port(),
)
def get_remote_base_url():
http_conf = Config.get('backend.http') or {}
return '{proto}://{host}:{port}'.format(
proto=('https' if http_conf.get('ssl_cert') else 'http'),
host=get_ip_or_hostname(),
port=get_http_port(),
)
# vim:sw=4:ts=4:et:

View file

@ -0,0 +1,37 @@
from .auth import (
authenticate,
authenticate_token,
authenticate_user_pass,
get_auth_status,
)
from .bus import bus, get_message_response, send_message, send_request
from .logger import logger
from .routes import (
get_http_port,
get_ip_or_hostname,
get_local_base_url,
get_remote_base_url,
get_routes,
)
from .ws import get_ws_routes
__all__ = [
'authenticate',
'authenticate_token',
'authenticate_user_pass',
'bus',
'get_auth_status',
'get_http_port',
'get_ip_or_hostname',
'get_local_base_url',
'get_message_response',
'get_remote_base_url',
'get_routes',
'get_ws_routes',
'logger',
'send_message',
'send_request',
]
# vim:sw=4:ts=4:et:

View file

@ -0,0 +1,196 @@
import base64
from functools import wraps
from typing import Optional
from flask import request, redirect, jsonify
from flask.wrappers import Response
from platypush.config import Config
from platypush.user import UserManager
from ..logger import logger
from .status import AuthStatus
user_manager = UserManager()
def get_arg(req, name: str) -> Optional[str]:
# The Flask way
if hasattr(req, 'args'):
return req.args.get(name)
# The Tornado way
if hasattr(req, 'arguments'):
arg = req.arguments.get(name)
if arg:
return arg[0].decode()
return None
def get_cookie(req, name: str) -> Optional[str]:
cookie = req.cookies.get(name)
if not cookie:
return None
# The Flask way
if isinstance(cookie, str):
return cookie
# The Tornado way
return cookie.value
def authenticate_token(req):
token = Config.get('token')
user_token = None
if 'X-Token' in req.headers:
user_token = req.headers['X-Token']
elif 'Authorization' in req.headers and req.headers['Authorization'].startswith(
'Bearer '
):
user_token = req.headers['Authorization'][7:]
else:
user_token = get_arg(req, 'token')
if not user_token:
return False
try:
user_manager.validate_jwt_token(user_token)
return True
except Exception as e:
logger().debug(str(e))
return bool(token and user_token == token)
def authenticate_user_pass(req):
# Flask populates request.authorization
if hasattr(req, 'authorization'):
if not req.authorization:
return False
username = req.authorization.username
password = req.authorization.password
# Otherwise, check the Authorization header
elif 'Authorization' in req.headers and req.headers['Authorization'].startswith(
'Basic '
):
auth = req.headers['Authorization'][6:]
try:
auth = base64.b64decode(auth)
except ValueError:
pass
username, password = auth.decode().split(':', maxsplit=1)
else:
return False
return user_manager.authenticate_user(username, password)
def authenticate_session(req):
user = None
# Check the X-Session-Token header
user_session_token = req.headers.get('X-Session-Token')
# Check the `session_token` query/body parameter
if not user_session_token:
user_session_token = get_arg(req, 'session_token')
# Check the `session_token` cookie
if not user_session_token:
user_session_token = get_cookie(req, 'session_token')
if user_session_token:
user, _ = user_manager.authenticate_user_session(user_session_token)
return user is not None
def authenticate(
redirect_page='',
skip_auth_methods=None,
json=False,
):
"""
Authentication decorator for Flask routes.
"""
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
auth_status = get_auth_status(
request,
skip_auth_methods=skip_auth_methods,
)
if auth_status == AuthStatus.OK:
return f(*args, **kwargs)
if json:
return jsonify(auth_status.to_dict()), auth_status.value.code
if auth_status == AuthStatus.NO_USERS:
return redirect(
f'/register?redirect={redirect_page or request.url}', 307
)
if auth_status == AuthStatus.UNAUTHORIZED:
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
return Response(
'Authentication required',
401,
{'WWW-Authenticate': 'Basic realm="Login required"'},
)
return wrapper
return decorator
# pylint: disable=too-many-return-statements
def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
"""
Check against the available authentication methods (except those listed in
``skip_auth_methods``) if the user is properly authenticated.
"""
n_users = user_manager.get_user_count()
skip_methods = skip_auth_methods or []
# User/pass HTTP authentication
http_auth_ok = True
if n_users > 0 and 'http' not in skip_methods:
http_auth_ok = authenticate_user_pass(req)
if http_auth_ok:
return AuthStatus.OK
# Token-based authentication
token_auth_ok = True
if 'token' not in skip_methods:
token_auth_ok = authenticate_token(req)
if token_auth_ok:
return AuthStatus.OK
# Session token based authentication
session_auth_ok = True
if n_users > 0 and 'session' not in skip_methods:
return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED
# At least a user should be created before accessing an authenticated resource
if n_users == 0 and 'session' not in skip_methods:
return AuthStatus.NO_USERS
if ( # pylint: disable=too-many-boolean-expressions
('http' not in skip_methods and http_auth_ok)
or ('token' not in skip_methods and token_auth_ok)
or ('session' not in skip_methods and session_auth_ok)
):
return AuthStatus.OK
return AuthStatus.UNAUTHORIZED

View file

@ -0,0 +1,21 @@
from collections import namedtuple
from enum import Enum
StatusValue = namedtuple('StatusValue', ['code', 'message'])
class AuthStatus(Enum):
"""
Models the status of the authentication.
"""
OK = StatusValue(200, 'OK')
UNAUTHORIZED = StatusValue(401, 'Unauthorized')
NO_USERS = StatusValue(412, 'Please create a user first')
def to_dict(self):
return {
'code': self.value[0],
'message': self.value[1],
}

View file

@ -0,0 +1,63 @@
from flask import current_app
from redis import Redis
from platypush.bus.redis import RedisBus
from platypush.config import Config
from platypush.message import Message
from platypush.message.request import Request
from platypush.utils import get_redis_queue_name_by_message
from .logger import logger
_bus = None
def bus():
global _bus
if _bus is None:
_bus = RedisBus(redis_queue=current_app.config.get('redis_queue'))
return _bus
def send_message(msg, wait_for_response=True):
msg = Message.build(msg)
if msg is None:
return
if isinstance(msg, Request):
msg.origin = 'http'
if Config.get('token'):
msg.token = Config.get('token')
bus().post(msg)
if isinstance(msg, Request) and wait_for_response:
response = get_message_response(msg)
logger().debug('Processing response on the HTTP backend: {}'.format(response))
return response
def send_request(action, wait_for_response=True, **kwargs):
msg = {'type': 'request', 'action': action}
if kwargs:
msg['args'] = kwargs
return send_message(msg, wait_for_response=wait_for_response)
def get_message_response(msg):
redis = Redis(**bus().redis_args)
redis_queue = get_redis_queue_name_by_message(msg)
if not redis_queue:
return
response = redis.blpop(redis_queue, timeout=60)
if response and len(response) > 1:
response = Message.build(response[1])
else:
response = None
return response

View file

@ -0,0 +1,31 @@
import logging
from platypush.config import Config
_logger = None
def logger():
global _logger
if not _logger:
log_args = {
'level': logging.INFO,
'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s',
}
level = (Config.get('backend.http') or {}).get('logging') or (
Config.get('logging') or {}
).get('level')
filename = (Config.get('backend.http') or {}).get('filename')
if level:
log_args['level'] = (
getattr(logging, level.upper()) if isinstance(level, str) else level
)
if filename:
log_args['filename'] = filename
logging.basicConfig(**log_args)
_logger = logging.getLogger('platypush:web')
return _logger

View file

@ -0,0 +1,59 @@
import importlib
import inspect
import os
import pkgutil
from platypush.backend import Backend
from platypush.config import Config
from platypush.utils import get_ip_or_hostname
from .logger import logger
def get_http_port():
from platypush.backend.http import HttpBackend
http_conf = Config.get('backend.http') or {}
return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT)
def get_routes():
base_pkg = '.'.join([Backend.__module__, 'http', 'app', 'routes'])
base_dir = os.path.join(
os.path.dirname(inspect.getfile(Backend)), 'http', 'app', 'routes'
)
routes = []
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'):
try:
module = importlib.import_module(mod_name)
if hasattr(module, '__routes__'):
routes.extend(module.__routes__)
except Exception as e:
logger.warning('Could not import module %s', mod_name)
logger.exception(e)
continue
return routes
def get_local_base_url():
http_conf = Config.get('backend.http') or {}
bind_address = http_conf.get('bind_address')
if not bind_address or bind_address == '0.0.0.0':
bind_address = 'localhost'
return '{proto}://{host}:{port}'.format(
proto=('https' if http_conf.get('ssl_cert') else 'http'),
host=bind_address,
port=get_http_port(),
)
def get_remote_base_url():
http_conf = Config.get('backend.http') or {}
return '{proto}://{host}:{port}'.format(
proto=('https' if http_conf.get('ssl_cert') else 'http'),
host=get_ip_or_hostname(),
port=get_http_port(),
)

View file

@ -5,18 +5,20 @@ from typing import List, Type
import pkgutil
from ._base import WSRoute, logger
from ..ws import WSRoute, logger
def scan_routes() -> List[Type[WSRoute]]:
def get_ws_routes() -> List[Type[WSRoute]]:
"""
Scans for websocket route objects.
"""
from platypush.backend.http import HttpBackend
base_dir = os.path.dirname(__file__)
base_pkg = '.'.join([HttpBackend.__module__, 'app', 'ws'])
base_dir = os.path.join(os.path.dirname(inspect.getfile(HttpBackend)), 'app', 'ws')
routes = []
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=__package__ + '.'):
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'):
try:
module = importlib.import_module(mod_name)
except Exception as e:

View file

@ -0,0 +1,3 @@
from ._base import WSRoute, logger, pubsub_redis_topic
__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic']

View file

@ -1,4 +1,5 @@
from abc import ABC, abstractclassmethod
import json
from logging import getLogger
from threading import RLock, Thread
from typing import Any, Generator, Iterable, Optional, Union
@ -8,7 +9,9 @@ from redis import ConnectionError as RedisConnectionError
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
from platypush.config import Config
from platypush.message import Message
from platypush.utils import get_redis
logger = getLogger(__name__)
@ -32,7 +35,14 @@ class WSRoute(WebSocketHandler, Thread, ABC):
@override
def open(self, *_, **__):
logger.info('Client %s connected to %s', self.request.remote_ip, self.path())
auth_status = get_auth_status(self.request)
if auth_status != AuthStatus.OK:
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
return
logger.info(
'Client %s connected to %s', self.request.remote_ip, self.request.path
)
self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
self.start()
@ -52,6 +62,10 @@ class WSRoute(WebSocketHandler, Thread, ABC):
def path(cls) -> str:
return f'/ws/{cls.app_name()}'
@property
def auth_required(self):
return True
def subscribe(self, *topics: str) -> None:
with self._sub_lock:
for topic in topics:
@ -78,7 +92,12 @@ class WSRoute(WebSocketHandler, Thread, ABC):
except RedisConnectionError:
return
def send(self, msg: Union[str, bytes]) -> None:
def send(self, msg: Union[str, bytes, dict, list, tuple, set]) -> None:
if isinstance(msg, (list, tuple, set)):
msg = list(msg)
if isinstance(msg, (list, dict)):
msg = json.dumps(msg, cls=Message.Encoder)
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
self.write_message, msg
)
@ -99,7 +118,7 @@ class WSRoute(WebSocketHandler, Thread, ABC):
logger.info(
'Client %s disconnected from %s, reason=%s, message=%s',
self.request.remote_ip,
self.path(),
self.request.path,
self.close_code,
self.close_reason,
)

View file

@ -1,4 +0,0 @@
from ._base import WSRoute, logger, pubsub_redis_topic
from ._scanner import scan_routes
__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic', 'scan_routes']