Commit 27e52d2c authored by Vít Novotný's avatar Vít Novotný
Browse files

Add `NerModel.test()`

parent e8e1d2d4
......@@ -49,6 +49,27 @@ class NerModel:
def __repr__(self) -> str:
return '{}: {}'.format(self.__class__.__name__, self)
def test(self, test_tagged_sentence_basename: str) -> float:
lang_module = LangModule(self.model_name_or_basename)
ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
ner_evaluators = [MeanFScore(evaluation_strategy='steps')]
ner_objective = TokenClassification(lang_module,
batch_size=1,
texts_or_path=ner_test_texts,
labels_or_path=ner_test_labels,
val_texts_or_path=ner_testn_texts,
val_labels_or_path=ner_testn_labels,
val_evaluators=ner_evaluators)
adaptation_arguments = AdaptationArguments()
schedule = get_schedule('sequential', [ner_objective], adaptation_arguments)
adapter = Adapter(lang_module, schedule, adaptation_arguments)
test_result = adapter.evaluate()
test_f_score = test_result['eval_MeanFScore']
return test_f_score
@classmethod
def train_and_save(cls, model_checkpoint_basename: str, model_basename: str,
training_sentence_basename: str, validation_sentence_basename: str,
......@@ -71,14 +92,6 @@ class NerModel:
texts_or_path=mlm_training_texts,
val_texts_or_path=mlm_validation_texts)
# Set up named entity recognition (NER) training
def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
ner_texts, all_ner_tags = [], []
for tagged_sentence in TaggedSentence.load(tagged_sentence_basename):
ner_texts.append(tagged_sentence.sentence)
all_ner_tags.append(tagged_sentence.bio_ner_tags)
return ner_texts, all_ner_tags
ner_training_texts, ner_training_labels = load_ner_dataset(training_tagged_sentence_basename)
ner_validation_texts, ner_validation_labels = load_ner_dataset(validation_tagged_sentence_basename)
ner_validation_texts = ner_validation_texts[:cls.NUM_VALIDATION_SAMPLES]
......@@ -131,3 +144,11 @@ class NerModel:
model_pathname = model_pathname / 'TokenClassification'
model_name_or_basename = str(model_pathname)
return cls(model_name_or_basename)
def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
ner_texts, all_ner_tags = [], []
for tagged_sentence in TaggedSentence.load(tagged_sentence_basename):
ner_texts.append(tagged_sentence.sentence)
all_ner_tags.append(tagged_sentence.bio_ner_tags)
return ner_texts, all_ner_tags
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment