Commit 590372a4 authored by Vít Novotný's avatar Vít Novotný
Browse files

Prevent future objectives from affecting early stopping

parent c4c8ffd3
......@@ -10,42 +10,48 @@ from ..config import CONFIG as _CONFIG
ScheduleName = str
class FairSequentialSchedule(SequentialSchedule):
class FairSequentialSchedule(Schedule):
CONFIG = _CONFIG['recognition.FairSequentialSchedule']
MAX_NUM_TRAIN_EPOCHS = CONFIG.getint('maximum_number_of_training_epochs_per_objective')
label = 'fair_sequential'
def _sample_objectives(self, split: str) -> Iterable[Objective]:
assert split == 'train'
while True:
for objective in self.objectives[split].values():
starting_epoch = objective.epoch
for _ in range(objective.dataset_length[split]):
if objective in self.converged_objectives and not self.args.log_converged_objectives:
continue
if split == 'train':
num_train_epochs = objective.epoch - starting_epoch
if num_train_epochs >= self.MAX_NUM_TRAIN_EPOCHS:
continue
num_train_epochs = objective.epoch - starting_epoch
if num_train_epochs >= self.MAX_NUM_TRAIN_EPOCHS:
continue
yield objective
class FineTuningSchedule(SequentialSchedule):
class FineTuningSchedule(Schedule):
CONFIG = _CONFIG['recognition.FineTuningSchedule']
MAX_NUM_TRAIN_EPOCHS = CONFIG.getint('maximum_number_of_training_epochs_per_objective')
label = 'fine_tuning'
def _sample_objectives(self, split: str) -> Iterable[Objective]:
for objective in self.objectives[split].values():
assert split == 'train'
objectives = self.objectives[split].values()
remaining_objectives = list(objectives)
for objective in objectives:
remaining_objectives = remaining_objectives[1:]
starting_epoch = objective.epoch
while True:
# Prevent future objectives from affecting early stopping
for remaining_objective in remaining_objectives:
remaining_objective.evaluations_history["eval"] = {}
if objective in self.converged_objectives and not self.args.log_converged_objectives:
break
if split == 'train':
num_train_epochs = objective.epoch - starting_epoch
if num_train_epochs >= self.MAX_NUM_TRAIN_EPOCHS:
break
num_train_epochs = objective.epoch - starting_epoch
if num_train_epochs >= self.MAX_NUM_TRAIN_EPOCHS:
break
yield objective
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment