Commit 3b279510 authored by Ondřej Borýsek's avatar Ondřej Borýsek
Browse files

Refactor single threaded init

parent 4de4a8d9
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@ def download_api_examples():

@bp.route("/force_upload_init_data", methods=["GET"])
def force_upload_init_data():
    pwndoc_db_init.InitialData(force=True)
    pwndoc_db_init.InitialData().upload_initial_data()
    return basic_msg("API inited")


+28 −15
Original line number Diff line number Diff line
import random
import time

from flask import flash, redirect, url_for
from loguru import logger
import helpers.custom_logging
@@ -11,9 +14,7 @@ from helpers.file_utils import *


def create_app(skip_pwndoc: bool = False):
    import pwndoc_db_init
    from helpers.db_models import db
    from helpers.template_grouping import TemplateGrouping

    from api_process_findings import bp as process_findings_bp
    from api_templates import bp as templates_bp
@@ -25,12 +26,6 @@ def create_app(skip_pwndoc: bool = False):
    logger.add("user_data/importer-logs/main.log", rotation="1 week", retention="60 days", compression="gz", encoding="utf8")
    Path(relative_path('api_examples/')).mkdir(parents=True, exist_ok=True)

    pwndoc_db_init.PwnDocUpdate.setup_scan2report_plugin_folder()
    main_thread = True
    if not skip_pwndoc:
        main_thread = pwndoc_db_init.InitialData().main_thread
        pwndoc_db_init.PwnDocUpdate.add_new_fields_if_needed()

    app = Flask(__name__, static_url_path=f"{config.IMPORTER_URL_PREFIX}/static")

    if config.BEHIND_REVERSE_PROXY:
@@ -50,14 +45,11 @@ def create_app(skip_pwndoc: bool = False):
        else:
            migrate.init_app(app, db)

    # db.create_all() need to happen single threaded (or staggered)
    time.sleep(random.random())  # todo: do this properly
    with app.app_context():
        db.create_all()

    if not skip_pwndoc and main_thread:
        # todo: ensure only one thread does this?
        with app.app_context():
            pwndoc_db_init.PwnDocUpdate.upload_templates_from_scan2report_repository()

    app.config['UPLOAD_FOLDER'] = config.FLASK_UPLOAD_FOLDER
    app.config['MAX_CONTENT_LENGTH'] = 100 * 1000 * 1000
    app.config["JSONIFY_PRETTYPRINT_REGULAR"] = True
@@ -65,6 +57,8 @@ def create_app(skip_pwndoc: bool = False):

    Path(app.config['UPLOAD_FOLDER']).mkdir(parents=True, exist_ok=True)

    single_threaded_init(app, skip_pwndoc)

    app.register_blueprint(process_findings_bp, url_prefix=f'{config.IMPORTER_URL_PREFIX}/findings')
    app.register_blueprint(templates_bp, url_prefix=f'{config.IMPORTER_URL_PREFIX}/templates')
    app.register_blueprint(pwndoc_bp, url_prefix=f'{config.IMPORTER_URL_PREFIX}/pwndoc_audit')
@@ -84,8 +78,6 @@ def create_app(skip_pwndoc: bool = False):
    # Todo: possibly disable this or require additional password
    app.register_blueprint(debug_bp, url_prefix=f'{config.IMPORTER_URL_PREFIX}/debug')

    TemplateGrouping.add_new_gids()

    @app.errorhandler(500)
    def internal_error(error):
        flash(f"Internal server error (more information is in server logs)", "danger")
@@ -106,6 +98,27 @@ def create_app(skip_pwndoc: bool = False):
    return app


def single_threaded_init(app, skip_pwndoc: bool):
    from helpers.template_grouping import TemplateGrouping
    import pwndoc_db_init

    main_thread = True if skip_pwndoc else pwndoc_db_init.InitialData().run_init()

    if not main_thread:
        return  # todo: All threads except the main one will be responsive before the full setup is completed. User actions could cause race condition.

    pwndoc_db_init.PwnDocUpdate.setup_scan2report_plugin_folder()  # This will fail if run second time.
    TemplateGrouping.add_new_gids()

    if skip_pwndoc:
        return

    pwndoc_db_init.PwnDocUpdate.add_new_fields_if_needed()

    with app.app_context():
        pwndoc_db_init.PwnDocUpdate.upload_templates_from_scan2report_repository()


if __name__ == "__main__":
    app = create_app()
    app.run()
+15 −20
Original line number Diff line number Diff line
@@ -20,32 +20,30 @@ from pwndoc_api import session, test_connection

class InitialData:
    __NEW_MAPPING_FILEPATH = relative_path(f'user_data/pwndoc-init/new_id_mapping.json')
    __INIT_INDICATOR_NAME = "INIT_STATUS"
    __INIT_FINISHED = "INIT_FINISHED"

    def __init__(self, force=False):
    def __init__(self):
        self.main_thread: bool = False
        time.sleep(random.uniform(0, 1))  # randomly offset threads to not overload pwndoc

        self._save_new_id_mapping("NO_CONTENT", "LOREM_IPSUM")  # force file creation if doesn't exist
        was_db_already_initialized = self._load_new_id_mapping(self.__INIT_INDICATOR_NAME) == self.__INIT_FINISHED

        is_db_clean = self._is_db_clean()
        force |= is_db_clean
    def run_init(self) -> bool:
        self.setup_first_user()  # this can set self.main_thread
        # login happens automatically
        self.upload_initial_data()
        logger.info("PwnDoc Init completed")
        return self.main_thread

        if was_db_already_initialized and not force:
    def setup_first_user(self):
        if not self._is_db_clean(max_seconds=30):  # checks if DB is empty, wait if necessary
            logger.info("DB seems to be (at least partially) initialized - skipping initialization.")
            return

        self.main_thread = self._add_pwndoc_user(first_user=True)
        # login happens automatically

        # The thread that created the user will continue. Other threads will block until pwndoc is setup (or timeout).
        # The thread that created the user will continue. Other threads will block until user is created (or timeout).
        if not self.main_thread:
            self.wait_until_pwndoc_setup()
            self.wait_until_pwndoc_user_is_ready(wait_max_x_seconds=30)
            return

        # no dependencies
    def upload_initial_data(self):
        self._upload_universal('_api_templates.json')
        self._upload_universal('_api_data_languages.json')
        self._upload_universal('_api_data_sections.json')
@@ -53,18 +51,15 @@ class InitialData:
        self._upload_universal('_api_data_audit-types.json')  # requires: templates, languages, sections
        self._upload_universal('_api_data_custom-fields.json')  # requires: everything

        self._save_new_id_mapping(self.__INIT_INDICATOR_NAME, self.__INIT_FINISHED)
        logger.info("PwnDoc Init completed")

    @staticmethod
    def wait_until_pwndoc_setup(wait_max_x_seconds=30):
    def wait_until_pwndoc_user_is_ready(wait_max_x_seconds=30):
        while wait_max_x_seconds > 0 and test_connection() != 200:
            time.sleep(1)
            wait_max_x_seconds -= 1
        assert test_connection() == 200, "After 30 seconds everything should be ready, but isn't. Restarting thread."

    @staticmethod
    def _is_db_clean():
    def _is_db_clean(max_seconds: float = 30) -> bool:
        for i in range(10):
            try:
                resp = session.get(f"{config.PWNDOC_URL}/api/users/init")
@@ -76,7 +71,7 @@ class InitialData:
                if config.APP_INIT_PWNDOC_DISALLOW_WAITING:
                    break
                logger.warning(f"Connection to {config.PWNDOC_URL} failed. Waiting before retry.")
                time.sleep(random.uniform(1, 2))
                time.sleep(max_seconds/10 + random.uniform(-0.1, 0.1))

        logger.warning(f"Connection to {config.PWNDOC_URL} failed. Not retrying.")
        return False