Commit 98350de5 authored by Vít Novotný's avatar Vít Novotný
Browse files

Make `NerModel.train_and_save()` return model

parent 7263450d
......@@ -54,7 +54,7 @@ class NerModel:
training_sentence_basename: str, validation_sentence_basename: str,
training_tagged_sentence_basename: str,
validation_tagged_sentence_basename: str,
schedule_name: Optional[ScheduleName] = None) -> None:
schedule_name: Optional[ScheduleName] = None) -> 'NerModel':
if schedule_name is None:
schedule_name = cls.SCHEDULE_NAME
......@@ -123,9 +123,11 @@ class NerModel:
model_pathname = cls.ROOT_PATH / model_basename
adapter.save_model(str(model_pathname))
return cls.load(model_basename)
@classmethod
def load(cls, model_basename: str) -> 'NerModel':
def load(cls, basename: str) -> 'NerModel':
model_pathname = cls.ROOT_PATH / model_basename
ner_model_pathname = model_pathname / 'TokenClassification'
ner_model = AutoModelForTokenClassification.from_pretrained(str(ner_model_pathname))
return cls(ner_model)
model_pathname = model_pathname / 'TokenClassification'
model_name_or_basename = str(model_pathname)
return cls(model_name_or_basename)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment