Commit 7263450d authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Make `NerModel.__init__()` lazy and add `str()` and `repr()`

parent b78567bb
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -36,8 +36,18 @@ class NerModel:
    NUM_VALIDATION_SAMPLES = CONFIG.getint('number_of_validation_samples')
    STOPPING_PATIENCE = CONFIG.getint('stopping_patience')

    def __init__(self, model: AutoModelForTokenClassification):
        self.model = model
    def __init__(self, model_name_or_basename: str):
        self.model_name_or_basename = model_name_or_basename

    @property
    def model(self) -> AutoModelForTokenClassification:
        model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_basename)

    def __str__(self) -> str:
        return self.model_name_or_basename

    def __repr__(self) -> str:
        return '{}: {}'.format(self.__class__.__name__, self)

    @classmethod
    def train_and_save(cls, model_checkpoint_basename: str, model_basename: str,