Commit 27e52d2c authored by Vít Starý Novotný
Browse files

Add `NerModel.test()`

parent e8e1d2d4
Loading
Loading
Loading
Loading
+29 −8
Original line number Diff line number Diff line
@@ -49,6 +49,27 @@ class NerModel:
    def __repr__(self) -> str:
        return '{}: {}'.format(self.__class__.__name__, self)

    def test(self, test_tagged_sentence_basename: str) -> float:
        """Evaluate this model on a tagged test set and return its mean F-score.

        The test sentences are handed to the token-classification objective as
        both its training and its validation data. No training is started
        here; only ``adapter.evaluate()`` is called.

        :param test_tagged_sentence_basename: Basename of the tagged test
            sentences, passed unchanged to ``load_ner_dataset``.
        :return: The value stored under ``'eval_MeanFScore'`` in the result
            of ``adapter.evaluate()``.
        """
        lang_module = LangModule(self.model_name_or_basename)

        ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
        ner_evaluators = [MeanFScore(evaluation_strategy='steps')]
        # The same test data fills the training and the validation slots.
        # NOTE(review): presumably evaluate() scores only the validation
        # split with `val_evaluators` -- confirm against the adaptor library.
        ner_objective = TokenClassification(lang_module,
                                            batch_size=1,
                                            texts_or_path=ner_test_texts,
                                            labels_or_path=ner_test_labels,
                                            val_texts_or_path=ner_test_texts,
                                            val_labels_or_path=ner_test_labels,
                                            val_evaluators=ner_evaluators)

        adaptation_arguments = AdaptationArguments()
        schedule = get_schedule('sequential', [ner_objective], adaptation_arguments)
        adapter = Adapter(lang_module, schedule, adaptation_arguments)

        test_result = adapter.evaluate()
        return test_result['eval_MeanFScore']

    @classmethod
    def train_and_save(cls, model_checkpoint_basename: str, model_basename: str,
                       training_sentence_basename: str, validation_sentence_basename: str,
@@ -71,14 +92,6 @@ class NerModel:
                                               texts_or_path=mlm_training_texts,
                                               val_texts_or_path=mlm_validation_texts)

        # Set up named entity recognition (NER) training
        def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
            """Load tagged sentences and return their texts and BIO tags as two parallel lists."""
            sentences, tag_sequences = [], []
            for item in TaggedSentence.load(tagged_sentence_basename):
                sentences.append(item.sentence)
                tag_sequences.append(item.bio_ner_tags)
            return sentences, tag_sequences

        ner_training_texts, ner_training_labels = load_ner_dataset(training_tagged_sentence_basename)
        ner_validation_texts, ner_validation_labels = load_ner_dataset(validation_tagged_sentence_basename)
        ner_validation_texts = ner_validation_texts[:cls.NUM_VALIDATION_SAMPLES]
@@ -131,3 +144,11 @@ class NerModel:
        model_pathname = model_pathname / 'TokenClassification'
        model_name_or_basename = str(model_pathname)
        return cls(model_name_or_basename)


def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
    """Load tagged sentences and split them into parallel lists.

    :param tagged_sentence_basename: Basename passed to ``TaggedSentence.load``.
    :return: A pair ``(texts, tags)`` where ``texts[i]`` is the ``sentence``
        and ``tags[i]`` the ``bio_ner_tags`` of the i-th loaded item.
    """
    # Read both attributes of each item in one pass, then separate the columns.
    pairs = [
        (tagged.sentence, tagged.bio_ner_tags)
        for tagged in TaggedSentence.load(tagged_sentence_basename)
    ]
    ner_texts = [text for text, _ in pairs]
    all_ner_tags = [tags for _, tags in pairs]
    return ner_texts, all_ner_tags