Commit c4c8ffd3 authored by Vít Starý Novotný
Browse files

Use fine-tuning schedule to train NER models

parent 04b847a2
Loading
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -102,3 +102,6 @@ schedule = fair-sequential-schedule

[recognition.FairSequentialSchedule]
maximum_number_of_training_epochs_per_objective = 1

[recognition.FineTuningSchedule]
maximum_number_of_training_epochs_per_objective = 5
+21 −0
Original line number Diff line number Diff line
@@ -30,6 +30,25 @@ class FairSequentialSchedule(SequentialSchedule):
                    yield objective


class FineTuningSchedule(SequentialSchedule):
    """Sequential schedule that stays on one objective before moving to the next.

    Objectives of a split are visited once each, in the order in which they
    are stored in ``self.objectives[split]``. The per-objective cap on
    training epochs is read from the ``recognition.FineTuningSchedule``
    configuration section.
    """

    CONFIG = _CONFIG['recognition.FineTuningSchedule']
    MAX_NUM_TRAIN_EPOCHS = CONFIG.getint('maximum_number_of_training_epochs_per_objective')

    label = 'fine_tuning'

    def _sample_objectives(self, split: str) -> Iterable[Objective]:
        """Yield every objective of ``split`` repeatedly until its turn is over.

        An objective's turn ends when it is listed in
        ``self.converged_objectives`` (unless
        ``self.args.log_converged_objectives`` is truthy) or, on the
        ``'train'`` split only, when its ``epoch`` counter has advanced by at
        least ``MAX_NUM_TRAIN_EPOCHS`` since the turn began. On any other
        split only the convergence check can end a turn.
        """
        limit_epochs = split == 'train'
        for current in self.objectives[split].values():
            # Remember where this objective stood when its turn started, so
            # that the epoch budget is counted relative to that point.
            first_epoch = current.epoch
            while True:
                still_active = (current not in self.converged_objectives
                                or self.args.log_converged_objectives)
                if not still_active:
                    break
                if limit_epochs and current.epoch - first_epoch >= self.MAX_NUM_TRAIN_EPOCHS:
                    break
                yield current


def get_schedule(schedule_name: str, objectives: Iterable[Objective],
                 adaptation_arguments: AdaptationArguments) -> Schedule:
    objectives = list(objectives)
@@ -37,6 +56,8 @@ def get_schedule(schedule_name: str, objectives: Iterable[Objective],
        schedule = SequentialSchedule(objectives, adaptation_arguments)
    elif schedule_name == 'fair-sequential':
        schedule = FairSequentialSchedule(objectives, adaptation_arguments)
    elif schedule_name == 'fine-tuning':
        schedule = FineTuningSchedule(objectives, adaptation_arguments)
    elif schedule_name == 'parallel':
        schedule = ParallelSchedule(objectives, adaptation_arguments)
    else:
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ IMAGE_NAME=ahisto/named-entity-search:latest
ROOT_PATH=/nlp/projekty/ahisto/public_html/named-entity-search/results/
ANNOTATION_PATH=/nlp/projekty/ahisto/annotations/
OCR_EVAL_PATH=/nlp/projekty/ahisto/ahisto-ocr-eval
SCHEDULE_NAME=parallel
SCHEDULE_NAME=fine-tuning

DOCKER_BUILDKIT=1 docker build --build-arg UID="$(id -u)" --build-arg GID="$(id -g)" --build-arg UNAME="$(id -u -n)" . -f scripts//03_train_ner_models.Dockerfile -t "$IMAGE_NAME"