forked from platypush/platypush
[#260] Implemented authentication for websocket routes.
Plus, refactored the `backend.http.app.utils` module by breaking it down into multiple components, as the module was starting to get too large.
This commit is contained in:
parent
2d4b179879
commit
edb7197f71
13 changed files with 441 additions and 338 deletions
|
@ -12,8 +12,8 @@ from tornado.ioloop import IOLoop
|
|||
|
||||
from platypush.backend import Backend
|
||||
from platypush.backend.http.app import application
|
||||
from platypush.backend.http.ws import scan_routes
|
||||
from platypush.backend.http.ws.events import events_redis_topic
|
||||
from platypush.backend.http.app.utils import get_ws_routes
|
||||
from platypush.backend.http.app.ws.events import events_redis_topic
|
||||
|
||||
from platypush.bus.redis import RedisBus
|
||||
from platypush.config import Config
|
||||
|
@ -264,7 +264,7 @@ class HttpBackend(Backend):
|
|||
container = WSGIContainer(application)
|
||||
server = Application(
|
||||
[
|
||||
*[(route.path(), route) for route in scan_routes()],
|
||||
*[(route.path(), route) for route in get_ws_routes()],
|
||||
(r'.*', FallbackHandler, {'fallback': container}),
|
||||
]
|
||||
)
|
||||
|
|
|
@ -1,324 +0,0 @@
|
|||
import importlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from functools import wraps
|
||||
from flask import abort, request, redirect, jsonify, current_app
|
||||
from flask.wrappers import Response
|
||||
from redis import Redis
|
||||
|
||||
# NOTE: The HTTP service will *only* work on top of a Redis bus. The default
|
||||
# internal bus service won't work as the web server will run in a different process.
|
||||
from platypush.bus.redis import RedisBus
|
||||
|
||||
from platypush.config import Config
|
||||
from platypush.message import Message
|
||||
from platypush.message.request import Request
|
||||
from platypush.user import UserManager
|
||||
from platypush.utils import get_redis_queue_name_by_message, get_ip_or_hostname
|
||||
|
||||
_bus = None
|
||||
_logger = None
|
||||
user_manager = UserManager()
|
||||
|
||||
|
||||
def bus():
|
||||
global _bus
|
||||
if _bus is None:
|
||||
_bus = RedisBus(redis_queue=current_app.config.get('redis_queue'))
|
||||
return _bus
|
||||
|
||||
|
||||
def logger():
|
||||
global _logger
|
||||
if not _logger:
|
||||
log_args = {
|
||||
'level': logging.INFO,
|
||||
'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s',
|
||||
}
|
||||
|
||||
level = (Config.get('backend.http') or {}).get('logging') or (
|
||||
Config.get('logging') or {}
|
||||
).get('level')
|
||||
filename = (Config.get('backend.http') or {}).get('filename')
|
||||
|
||||
if level:
|
||||
log_args['level'] = (
|
||||
getattr(logging, level.upper()) if isinstance(level, str) else level
|
||||
)
|
||||
if filename:
|
||||
log_args['filename'] = filename
|
||||
|
||||
logging.basicConfig(**log_args)
|
||||
_logger = logging.getLogger('platypush:web')
|
||||
|
||||
return _logger
|
||||
|
||||
|
||||
def get_message_response(msg):
|
||||
redis = Redis(**bus().redis_args)
|
||||
redis_queue = get_redis_queue_name_by_message(msg)
|
||||
if not redis_queue:
|
||||
return
|
||||
|
||||
response = redis.blpop(redis_queue, timeout=60)
|
||||
if response and len(response) > 1:
|
||||
response = Message.build(response[1])
|
||||
else:
|
||||
response = None
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def get_http_port():
|
||||
from platypush.backend.http import HttpBackend
|
||||
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT)
|
||||
|
||||
|
||||
def send_message(msg, wait_for_response=True):
|
||||
msg = Message.build(msg)
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
if isinstance(msg, Request):
|
||||
msg.origin = 'http'
|
||||
|
||||
if Config.get('token'):
|
||||
msg.token = Config.get('token')
|
||||
|
||||
bus().post(msg)
|
||||
|
||||
if isinstance(msg, Request) and wait_for_response:
|
||||
response = get_message_response(msg)
|
||||
logger().debug('Processing response on the HTTP backend: {}'.format(response))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def send_request(action, wait_for_response=True, **kwargs):
|
||||
msg = {'type': 'request', 'action': action}
|
||||
|
||||
if kwargs:
|
||||
msg['args'] = kwargs
|
||||
|
||||
return send_message(msg, wait_for_response=wait_for_response)
|
||||
|
||||
|
||||
def _authenticate_token():
|
||||
token = Config.get('token')
|
||||
user_token = None
|
||||
|
||||
if 'X-Token' in request.headers:
|
||||
user_token = request.headers['X-Token']
|
||||
elif 'Authorization' in request.headers and request.headers[
|
||||
'Authorization'
|
||||
].startswith('Bearer '):
|
||||
user_token = request.headers['Authorization'][7:]
|
||||
elif 'token' in request.args:
|
||||
user_token = request.args.get('token')
|
||||
if not user_token:
|
||||
return False
|
||||
|
||||
try:
|
||||
user_manager.validate_jwt_token(user_token)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger().debug(str(e))
|
||||
return token and user_token == token
|
||||
|
||||
|
||||
def _authenticate_http():
|
||||
if not request.authorization:
|
||||
return False
|
||||
|
||||
username = request.authorization.username
|
||||
password = request.authorization.password
|
||||
return user_manager.authenticate_user(username, password)
|
||||
|
||||
|
||||
def _authenticate_session():
|
||||
user_session_token = None
|
||||
user = None
|
||||
|
||||
if 'X-Session-Token' in request.headers:
|
||||
user_session_token = request.headers['X-Session-Token']
|
||||
elif 'session_token' in request.args:
|
||||
user_session_token = request.args.get('session_token')
|
||||
elif 'session_token' in request.cookies:
|
||||
user_session_token = request.cookies.get('session_token')
|
||||
|
||||
if user_session_token:
|
||||
user, _ = user_manager.authenticate_user_session(user_session_token)
|
||||
|
||||
return user is not None
|
||||
|
||||
|
||||
def _authenticate_csrf_token():
|
||||
user_session_token = None
|
||||
|
||||
if 'X-Session-Token' in request.headers:
|
||||
user_session_token = request.headers['X-Session-Token']
|
||||
elif 'session_token' in request.args:
|
||||
user_session_token = request.args.get('session_token')
|
||||
elif 'session_token' in request.cookies:
|
||||
user_session_token = request.cookies.get('session_token')
|
||||
|
||||
if user_session_token:
|
||||
user, session = user_manager.authenticate_user_session(user_session_token)
|
||||
else:
|
||||
return False
|
||||
|
||||
if user is None or session is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
session.csrf_token is None
|
||||
or request.form.get('csrf_token') == session.csrf_token
|
||||
)
|
||||
|
||||
|
||||
def authenticate(
|
||||
redirect_page='',
|
||||
skip_auth_methods=None,
|
||||
check_csrf_token=False,
|
||||
json=False,
|
||||
):
|
||||
def on_auth_fail(has_users=True):
|
||||
if json:
|
||||
if has_users:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
'message': 'Not logged in',
|
||||
}
|
||||
),
|
||||
401,
|
||||
)
|
||||
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
'message': 'Please register a user through '
|
||||
'the web panel first',
|
||||
}
|
||||
),
|
||||
412,
|
||||
)
|
||||
|
||||
target_page = 'login' if has_users else 'register'
|
||||
return redirect(f'/{target_page}?redirect={redirect_page or request.url}', 307)
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
n_users = user_manager.get_user_count()
|
||||
skip_methods = skip_auth_methods or []
|
||||
|
||||
# User/pass HTTP authentication
|
||||
http_auth_ok = True
|
||||
if n_users > 0 and 'http' not in skip_methods:
|
||||
http_auth_ok = _authenticate_http()
|
||||
if http_auth_ok:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# Token-based authentication
|
||||
token_auth_ok = True
|
||||
if 'token' not in skip_methods:
|
||||
token_auth_ok = _authenticate_token()
|
||||
if token_auth_ok:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# Session token based authentication
|
||||
session_auth_ok = True
|
||||
if n_users > 0 and 'session' not in skip_methods:
|
||||
session_auth_ok = _authenticate_session()
|
||||
if session_auth_ok:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return on_auth_fail()
|
||||
|
||||
# CSRF token check
|
||||
if check_csrf_token:
|
||||
csrf_check_ok = _authenticate_csrf_token()
|
||||
if not csrf_check_ok:
|
||||
return abort(403, 'Invalid or missing csrf_token')
|
||||
|
||||
if n_users == 0 and 'session' not in skip_methods:
|
||||
return on_auth_fail(has_users=False)
|
||||
|
||||
if (
|
||||
('http' not in skip_methods and http_auth_ok)
|
||||
or ('token' not in skip_methods and token_auth_ok)
|
||||
or ('session' not in skip_methods and session_auth_ok)
|
||||
):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return Response(
|
||||
'Authentication required',
|
||||
401,
|
||||
{'WWW-Authenticate': 'Basic realm="Login required"'},
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_routes():
|
||||
routes_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'routes')
|
||||
routes = []
|
||||
base_module = '.'.join(__name__.split('.')[:-1])
|
||||
|
||||
for path, _, files in os.walk(routes_dir):
|
||||
for f in files:
|
||||
if f.endswith('.py'):
|
||||
mod_name = '.'.join(
|
||||
(
|
||||
base_module
|
||||
+ '.'
|
||||
+ os.path.join(path, f)
|
||||
.replace(os.path.dirname(__file__), '')[1:]
|
||||
.replace(os.sep, '.')
|
||||
).split('.')[: (-2 if f == '__init__.py' else -1)]
|
||||
)
|
||||
|
||||
try:
|
||||
mod = importlib.import_module(mod_name)
|
||||
if hasattr(mod, '__routes__'):
|
||||
routes.extend(mod.__routes__)
|
||||
except Exception as e:
|
||||
logger().warning(
|
||||
'Could not import routes from {}/{}: {}: {}'.format(
|
||||
path, f, type(e), str(e)
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
||||
|
||||
def get_local_base_url():
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
bind_address = http_conf.get('bind_address')
|
||||
if not bind_address or bind_address == '0.0.0.0':
|
||||
bind_address = 'localhost'
|
||||
|
||||
return '{proto}://{host}:{port}'.format(
|
||||
proto=('https' if http_conf.get('ssl_cert') else 'http'),
|
||||
host=bind_address,
|
||||
port=get_http_port(),
|
||||
)
|
||||
|
||||
|
||||
def get_remote_base_url():
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
return '{proto}://{host}:{port}'.format(
|
||||
proto=('https' if http_conf.get('ssl_cert') else 'http'),
|
||||
host=get_ip_or_hostname(),
|
||||
port=get_http_port(),
|
||||
)
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
37
platypush/backend/http/app/utils/__init__.py
Normal file
37
platypush/backend/http/app/utils/__init__.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from .auth import (
|
||||
authenticate,
|
||||
authenticate_token,
|
||||
authenticate_user_pass,
|
||||
get_auth_status,
|
||||
)
|
||||
from .bus import bus, get_message_response, send_message, send_request
|
||||
from .logger import logger
|
||||
from .routes import (
|
||||
get_http_port,
|
||||
get_ip_or_hostname,
|
||||
get_local_base_url,
|
||||
get_remote_base_url,
|
||||
get_routes,
|
||||
)
|
||||
from .ws import get_ws_routes
|
||||
|
||||
__all__ = [
|
||||
'authenticate',
|
||||
'authenticate_token',
|
||||
'authenticate_user_pass',
|
||||
'bus',
|
||||
'get_auth_status',
|
||||
'get_http_port',
|
||||
'get_ip_or_hostname',
|
||||
'get_local_base_url',
|
||||
'get_message_response',
|
||||
'get_remote_base_url',
|
||||
'get_routes',
|
||||
'get_ws_routes',
|
||||
'logger',
|
||||
'send_message',
|
||||
'send_request',
|
||||
]
|
||||
|
||||
|
||||
# vim:sw=4:ts=4:et:
|
196
platypush/backend/http/app/utils/auth/__init__.py
Normal file
196
platypush/backend/http/app/utils/auth/__init__.py
Normal file
|
@ -0,0 +1,196 @@
|
|||
import base64
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import request, redirect, jsonify
|
||||
from flask.wrappers import Response
|
||||
|
||||
from platypush.config import Config
|
||||
from platypush.user import UserManager
|
||||
|
||||
from ..logger import logger
|
||||
from .status import AuthStatus
|
||||
|
||||
user_manager = UserManager()
|
||||
|
||||
|
||||
def get_arg(req, name: str) -> Optional[str]:
|
||||
# The Flask way
|
||||
if hasattr(req, 'args'):
|
||||
return req.args.get(name)
|
||||
|
||||
# The Tornado way
|
||||
if hasattr(req, 'arguments'):
|
||||
arg = req.arguments.get(name)
|
||||
if arg:
|
||||
return arg[0].decode()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_cookie(req, name: str) -> Optional[str]:
|
||||
cookie = req.cookies.get(name)
|
||||
if not cookie:
|
||||
return None
|
||||
|
||||
# The Flask way
|
||||
if isinstance(cookie, str):
|
||||
return cookie
|
||||
|
||||
# The Tornado way
|
||||
return cookie.value
|
||||
|
||||
|
||||
def authenticate_token(req):
|
||||
token = Config.get('token')
|
||||
user_token = None
|
||||
|
||||
if 'X-Token' in req.headers:
|
||||
user_token = req.headers['X-Token']
|
||||
elif 'Authorization' in req.headers and req.headers['Authorization'].startswith(
|
||||
'Bearer '
|
||||
):
|
||||
user_token = req.headers['Authorization'][7:]
|
||||
else:
|
||||
user_token = get_arg(req, 'token')
|
||||
|
||||
if not user_token:
|
||||
return False
|
||||
|
||||
try:
|
||||
user_manager.validate_jwt_token(user_token)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger().debug(str(e))
|
||||
return bool(token and user_token == token)
|
||||
|
||||
|
||||
def authenticate_user_pass(req):
|
||||
# Flask populates request.authorization
|
||||
if hasattr(req, 'authorization'):
|
||||
if not req.authorization:
|
||||
return False
|
||||
|
||||
username = req.authorization.username
|
||||
password = req.authorization.password
|
||||
|
||||
# Otherwise, check the Authorization header
|
||||
elif 'Authorization' in req.headers and req.headers['Authorization'].startswith(
|
||||
'Basic '
|
||||
):
|
||||
auth = req.headers['Authorization'][6:]
|
||||
try:
|
||||
auth = base64.b64decode(auth)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
username, password = auth.decode().split(':', maxsplit=1)
|
||||
else:
|
||||
return False
|
||||
|
||||
return user_manager.authenticate_user(username, password)
|
||||
|
||||
|
||||
def authenticate_session(req):
|
||||
user = None
|
||||
|
||||
# Check the X-Session-Token header
|
||||
user_session_token = req.headers.get('X-Session-Token')
|
||||
|
||||
# Check the `session_token` query/body parameter
|
||||
if not user_session_token:
|
||||
user_session_token = get_arg(req, 'session_token')
|
||||
|
||||
# Check the `session_token` cookie
|
||||
if not user_session_token:
|
||||
user_session_token = get_cookie(req, 'session_token')
|
||||
|
||||
if user_session_token:
|
||||
user, _ = user_manager.authenticate_user_session(user_session_token)
|
||||
|
||||
return user is not None
|
||||
|
||||
|
||||
def authenticate(
|
||||
redirect_page='',
|
||||
skip_auth_methods=None,
|
||||
json=False,
|
||||
):
|
||||
"""
|
||||
Authentication decorator for Flask routes.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
auth_status = get_auth_status(
|
||||
request,
|
||||
skip_auth_methods=skip_auth_methods,
|
||||
)
|
||||
|
||||
if auth_status == AuthStatus.OK:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
if json:
|
||||
return jsonify(auth_status.to_dict()), auth_status.value.code
|
||||
|
||||
if auth_status == AuthStatus.NO_USERS:
|
||||
return redirect(
|
||||
f'/register?redirect={redirect_page or request.url}', 307
|
||||
)
|
||||
|
||||
if auth_status == AuthStatus.UNAUTHORIZED:
|
||||
return redirect(f'/login?redirect={redirect_page or request.url}', 307)
|
||||
|
||||
return Response(
|
||||
'Authentication required',
|
||||
401,
|
||||
{'WWW-Authenticate': 'Basic realm="Login required"'},
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get_auth_status(req, skip_auth_methods=None) -> AuthStatus:
|
||||
"""
|
||||
Check against the available authentication methods (except those listed in
|
||||
``skip_auth_methods``) if the user is properly authenticated.
|
||||
"""
|
||||
|
||||
n_users = user_manager.get_user_count()
|
||||
skip_methods = skip_auth_methods or []
|
||||
|
||||
# User/pass HTTP authentication
|
||||
http_auth_ok = True
|
||||
if n_users > 0 and 'http' not in skip_methods:
|
||||
http_auth_ok = authenticate_user_pass(req)
|
||||
if http_auth_ok:
|
||||
return AuthStatus.OK
|
||||
|
||||
# Token-based authentication
|
||||
token_auth_ok = True
|
||||
if 'token' not in skip_methods:
|
||||
token_auth_ok = authenticate_token(req)
|
||||
if token_auth_ok:
|
||||
return AuthStatus.OK
|
||||
|
||||
# Session token based authentication
|
||||
session_auth_ok = True
|
||||
if n_users > 0 and 'session' not in skip_methods:
|
||||
return AuthStatus.OK if authenticate_session(req) else AuthStatus.UNAUTHORIZED
|
||||
|
||||
# At least a user should be created before accessing an authenticated resource
|
||||
if n_users == 0 and 'session' not in skip_methods:
|
||||
return AuthStatus.NO_USERS
|
||||
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
('http' not in skip_methods and http_auth_ok)
|
||||
or ('token' not in skip_methods and token_auth_ok)
|
||||
or ('session' not in skip_methods and session_auth_ok)
|
||||
):
|
||||
return AuthStatus.OK
|
||||
|
||||
return AuthStatus.UNAUTHORIZED
|
21
platypush/backend/http/app/utils/auth/status.py
Normal file
21
platypush/backend/http/app/utils/auth/status.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
|
||||
StatusValue = namedtuple('StatusValue', ['code', 'message'])
|
||||
|
||||
|
||||
class AuthStatus(Enum):
|
||||
"""
|
||||
Models the status of the authentication.
|
||||
"""
|
||||
|
||||
OK = StatusValue(200, 'OK')
|
||||
UNAUTHORIZED = StatusValue(401, 'Unauthorized')
|
||||
NO_USERS = StatusValue(412, 'Please create a user first')
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'code': self.value[0],
|
||||
'message': self.value[1],
|
||||
}
|
63
platypush/backend/http/app/utils/bus.py
Normal file
63
platypush/backend/http/app/utils/bus.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
from flask import current_app
|
||||
from redis import Redis
|
||||
|
||||
from platypush.bus.redis import RedisBus
|
||||
from platypush.config import Config
|
||||
from platypush.message import Message
|
||||
from platypush.message.request import Request
|
||||
from platypush.utils import get_redis_queue_name_by_message
|
||||
|
||||
from .logger import logger
|
||||
|
||||
_bus = None
|
||||
|
||||
|
||||
def bus():
|
||||
global _bus
|
||||
if _bus is None:
|
||||
_bus = RedisBus(redis_queue=current_app.config.get('redis_queue'))
|
||||
return _bus
|
||||
|
||||
|
||||
def send_message(msg, wait_for_response=True):
|
||||
msg = Message.build(msg)
|
||||
if msg is None:
|
||||
return
|
||||
|
||||
if isinstance(msg, Request):
|
||||
msg.origin = 'http'
|
||||
|
||||
if Config.get('token'):
|
||||
msg.token = Config.get('token')
|
||||
|
||||
bus().post(msg)
|
||||
|
||||
if isinstance(msg, Request) and wait_for_response:
|
||||
response = get_message_response(msg)
|
||||
logger().debug('Processing response on the HTTP backend: {}'.format(response))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def send_request(action, wait_for_response=True, **kwargs):
|
||||
msg = {'type': 'request', 'action': action}
|
||||
|
||||
if kwargs:
|
||||
msg['args'] = kwargs
|
||||
|
||||
return send_message(msg, wait_for_response=wait_for_response)
|
||||
|
||||
|
||||
def get_message_response(msg):
|
||||
redis = Redis(**bus().redis_args)
|
||||
redis_queue = get_redis_queue_name_by_message(msg)
|
||||
if not redis_queue:
|
||||
return
|
||||
|
||||
response = redis.blpop(redis_queue, timeout=60)
|
||||
if response and len(response) > 1:
|
||||
response = Message.build(response[1])
|
||||
else:
|
||||
response = None
|
||||
|
||||
return response
|
31
platypush/backend/http/app/utils/logger.py
Normal file
31
platypush/backend/http/app/utils/logger.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import logging
|
||||
|
||||
from platypush.config import Config
|
||||
|
||||
_logger = None
|
||||
|
||||
|
||||
def logger():
|
||||
global _logger
|
||||
if not _logger:
|
||||
log_args = {
|
||||
'level': logging.INFO,
|
||||
'format': '%(asctime)-15s|%(levelname)5s|%(name)s|%(message)s',
|
||||
}
|
||||
|
||||
level = (Config.get('backend.http') or {}).get('logging') or (
|
||||
Config.get('logging') or {}
|
||||
).get('level')
|
||||
filename = (Config.get('backend.http') or {}).get('filename')
|
||||
|
||||
if level:
|
||||
log_args['level'] = (
|
||||
getattr(logging, level.upper()) if isinstance(level, str) else level
|
||||
)
|
||||
if filename:
|
||||
log_args['filename'] = filename
|
||||
|
||||
logging.basicConfig(**log_args)
|
||||
_logger = logging.getLogger('platypush:web')
|
||||
|
||||
return _logger
|
59
platypush/backend/http/app/utils/routes.py
Normal file
59
platypush/backend/http/app/utils/routes.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
from platypush.backend import Backend
|
||||
from platypush.config import Config
|
||||
from platypush.utils import get_ip_or_hostname
|
||||
|
||||
from .logger import logger
|
||||
|
||||
|
||||
def get_http_port():
|
||||
from platypush.backend.http import HttpBackend
|
||||
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
return http_conf.get('port', HttpBackend._DEFAULT_HTTP_PORT)
|
||||
|
||||
|
||||
def get_routes():
|
||||
base_pkg = '.'.join([Backend.__module__, 'http', 'app', 'routes'])
|
||||
base_dir = os.path.join(
|
||||
os.path.dirname(inspect.getfile(Backend)), 'http', 'app', 'routes'
|
||||
)
|
||||
routes = []
|
||||
|
||||
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'):
|
||||
try:
|
||||
module = importlib.import_module(mod_name)
|
||||
if hasattr(module, '__routes__'):
|
||||
routes.extend(module.__routes__)
|
||||
except Exception as e:
|
||||
logger.warning('Could not import module %s', mod_name)
|
||||
logger.exception(e)
|
||||
continue
|
||||
|
||||
return routes
|
||||
|
||||
|
||||
def get_local_base_url():
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
bind_address = http_conf.get('bind_address')
|
||||
if not bind_address or bind_address == '0.0.0.0':
|
||||
bind_address = 'localhost'
|
||||
|
||||
return '{proto}://{host}:{port}'.format(
|
||||
proto=('https' if http_conf.get('ssl_cert') else 'http'),
|
||||
host=bind_address,
|
||||
port=get_http_port(),
|
||||
)
|
||||
|
||||
|
||||
def get_remote_base_url():
|
||||
http_conf = Config.get('backend.http') or {}
|
||||
return '{proto}://{host}:{port}'.format(
|
||||
proto=('https' if http_conf.get('ssl_cert') else 'http'),
|
||||
host=get_ip_or_hostname(),
|
||||
port=get_http_port(),
|
||||
)
|
|
@ -5,18 +5,20 @@ from typing import List, Type
|
|||
|
||||
import pkgutil
|
||||
|
||||
from ._base import WSRoute, logger
|
||||
from ..ws import WSRoute, logger
|
||||
|
||||
|
||||
def scan_routes() -> List[Type[WSRoute]]:
|
||||
def get_ws_routes() -> List[Type[WSRoute]]:
|
||||
"""
|
||||
Scans for websocket route objects.
|
||||
"""
|
||||
from platypush.backend.http import HttpBackend
|
||||
|
||||
base_dir = os.path.dirname(__file__)
|
||||
base_pkg = '.'.join([HttpBackend.__module__, 'app', 'ws'])
|
||||
base_dir = os.path.join(os.path.dirname(inspect.getfile(HttpBackend)), 'app', 'ws')
|
||||
routes = []
|
||||
|
||||
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=__package__ + '.'):
|
||||
for _, mod_name, _ in pkgutil.walk_packages([base_dir], prefix=base_pkg + '.'):
|
||||
try:
|
||||
module = importlib.import_module(mod_name)
|
||||
except Exception as e:
|
3
platypush/backend/http/app/ws/__init__.py
Normal file
3
platypush/backend/http/app/ws/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from ._base import WSRoute, logger, pubsub_redis_topic
|
||||
|
||||
__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic']
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractclassmethod
|
||||
import json
|
||||
from logging import getLogger
|
||||
from threading import RLock, Thread
|
||||
from typing import Any, Generator, Iterable, Optional, Union
|
||||
|
@ -8,7 +9,9 @@ from redis import ConnectionError as RedisConnectionError
|
|||
from tornado.ioloop import IOLoop
|
||||
from tornado.websocket import WebSocketHandler
|
||||
|
||||
from platypush.backend.http.app.utils.auth import AuthStatus, get_auth_status
|
||||
from platypush.config import Config
|
||||
from platypush.message import Message
|
||||
from platypush.utils import get_redis
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -32,7 +35,14 @@ class WSRoute(WebSocketHandler, Thread, ABC):
|
|||
|
||||
@override
|
||||
def open(self, *_, **__):
|
||||
logger.info('Client %s connected to %s', self.request.remote_ip, self.path())
|
||||
auth_status = get_auth_status(self.request)
|
||||
if auth_status != AuthStatus.OK:
|
||||
self.close(code=1008, reason=auth_status.value.message) # Policy Violation
|
||||
return
|
||||
|
||||
logger.info(
|
||||
'Client %s connected to %s', self.request.remote_ip, self.request.path
|
||||
)
|
||||
self.name = f'ws:{self.app_name()}@{self.request.remote_ip}'
|
||||
self.start()
|
||||
|
||||
|
@ -52,6 +62,10 @@ class WSRoute(WebSocketHandler, Thread, ABC):
|
|||
def path(cls) -> str:
|
||||
return f'/ws/{cls.app_name()}'
|
||||
|
||||
@property
|
||||
def auth_required(self):
|
||||
return True
|
||||
|
||||
def subscribe(self, *topics: str) -> None:
|
||||
with self._sub_lock:
|
||||
for topic in topics:
|
||||
|
@ -78,7 +92,12 @@ class WSRoute(WebSocketHandler, Thread, ABC):
|
|||
except RedisConnectionError:
|
||||
return
|
||||
|
||||
def send(self, msg: Union[str, bytes]) -> None:
|
||||
def send(self, msg: Union[str, bytes, dict, list, tuple, set]) -> None:
|
||||
if isinstance(msg, (list, tuple, set)):
|
||||
msg = list(msg)
|
||||
if isinstance(msg, (list, dict)):
|
||||
msg = json.dumps(msg, cls=Message.Encoder)
|
||||
|
||||
self._io_loop.asyncio_loop.call_soon_threadsafe( # type: ignore
|
||||
self.write_message, msg
|
||||
)
|
||||
|
@ -99,7 +118,7 @@ class WSRoute(WebSocketHandler, Thread, ABC):
|
|||
logger.info(
|
||||
'Client %s disconnected from %s, reason=%s, message=%s',
|
||||
self.request.remote_ip,
|
||||
self.path(),
|
||||
self.request.path,
|
||||
self.close_code,
|
||||
self.close_reason,
|
||||
)
|
|
@ -1,4 +0,0 @@
|
|||
from ._base import WSRoute, logger, pubsub_redis_topic
|
||||
from ._scanner import scan_routes
|
||||
|
||||
__all__ = ['WSRoute', 'logger', 'pubsub_redis_topic', 'scan_routes']
|
Loading…
Add table
Reference in a new issue