Verified Commit adae7d92 authored by Peter Stanko's avatar Peter Stanko
Browse files

Added authentication for the worker

parent 10324e42
Loading
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ class ContextPaths:

class AppContext:
    def __init__(self):
        self._auth_validator = None
        self._app = None
        from kontr_worker import tools, execution
        import celery
@@ -31,6 +32,13 @@ class AppContext:
    def paths(self) -> ContextPaths:
        return ContextPaths(context=self)

    @property
    def auth_validator(self):
        if self._auth_validator is None:
            from kontr_worker.service import auth
            self._auth_validator = auth.AuthValidator()
        return self._auth_validator

    @property
    def submissions(self) -> SubmissionsCollection:
        return self._submissions
@@ -51,6 +59,7 @@ class AppContext:
        self.redis.init_app(self._app)
        async_celery.init_app(self._app)
        self.portal.init_app(self._app)
        self.auth_validator.init_app(self.app)

    @property
    def app(self) -> flask.Flask:
+1 −0
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ class TestConfig(Config):
    PORTAL_URL = 'http://localhost'
    PORTAL_WORKER_SECRET = 'test_token'
    WORKSPACE_DIR = '/tmp/testing_workspace'
    WORKER_SECRET_FOR_PORTAL = 'secret-testing-token'


class ProductionConfig(Config):
+8 −2
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ Rest layer module
from flask import Flask
from flask_restplus import Api


API_PREFIX = "/api/v1.0"


@@ -34,6 +35,11 @@ def register_namespaces(app: Flask):
    from .submissions import submissions_namespace, result_namespace
    rest_api.add_namespace(submissions_namespace)
    rest_api.add_namespace(result_namespace)

    rest_api.init_app(app)
    register_error_handlers(app)
    return app


def register_error_handlers(app: Flask):
    from kontr_worker.rest import errors
    errors.load_errors(app)
+45 −3
Original line number Diff line number Diff line
import logging

import flask
from flask import Flask
from flask_restplus import abort

from kontr_worker.rest import rest_api

log = logging.getLogger(__name__)


def load_errors(app: Flask):
    log.debug("[LOAD] Custom error handlers loaded")
    for (ex, func) in rest_api.error_handlers.items():
        app.register_error_handler(ex, func)


@rest_api.errorhandler
def default_error_handler():
    return flask.jsonify({'message': 'Default error handler has been triggered'}), 400


class WorkerError(Exception):
    pass
    def __init__(self, message=None):
        self.message = message


class WorkerApiError(WorkerError):
    pass
    def __init__(self, code, message=None):
        super().__init__(message=message)
        self.code = code


class DataMissingError(WorkerApiError):
@@ -11,5 +36,22 @@ class DataMissingError(WorkerApiError):


class SubmissionError(WorkerApiError):
    def __init__(self, code, message):
        super(SubmissionError, self).__init__(message=message, code=code)


class AuthorizationFailed(WorkerApiError):
    def __init__(self, message):
        super(SubmissionError, self).__init__(message=message)
        super().__init__(message=message, code=401)


@rest_api.errorhandler(WorkerError)
def handle_general_worker_error(ex: WorkerError):
    log.error(f"[WORKER] Worker error: {ex} ")
    abort(code=400, message=ex.message)


@rest_api.errorhandler(WorkerApiError)
def handle_general_worker_api_error(ex: WorkerError):
    log.error(f"[API] Api error: {ex} ")
    abort(code=ex.code, message=ex.message)
+2 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import logging
from flask_restplus import Namespace, Resource

from kontr_worker.extensions import context
from kontr_worker.tools.decorators import require_authorization

management_namespace = Namespace('management')
log = logging.getLogger(__name__)
@@ -9,6 +10,7 @@ log = logging.getLogger(__name__)

@management_namespace.route('/status')
class StatusResource(Resource):
    @require_authorization
    def get(self):
        """Receives submission params as JSON
        Returns:
Loading