Commit faea9c48 authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Add `recognition.model.NerModel`

parent 6332e5db
Loading
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -87,3 +87,13 @@ number_of_results = 10
[search.Evaluation]
number_of_results = 10
f_score_beta = 0.25

[recognition.NerModel]
root_path = /nlp/projekty/ahisto/public_html/named-entity-search/results/
base_model = xlm-roberta-base
batch_size = 4
gradient_accumulation_steps = 4
log_every_n_steps = 100
evaluate_every_n_steps = 10000
save_every_n_steps = 10000
number_of_training_epochs = 10
+8 −0
Original line number Diff line number Diff line
from .model import (
    NerModel
)


# Public names re-exported by this subpackage.
__all__ = [
    'NerModel',
]
+128 −0
Original line number Diff line number Diff line
from __future__ import annotations

from logging import getLogger
from pathlib import Path
from typing import Tuple, List

import comet_ml  # noqa: F401
from adaptor.adapter import Adapter
from adaptor.objectives.classification import TokenClassification
from adaptor.evaluators.token_classification import MeanFScore
from adaptor.objectives.MLM import MaskedLanguageModeling
from adaptor.lang_module import LangModule
from adaptor.schedules import SequentialSchedule
from adaptor.utils import StoppingStrategy, AdaptationArguments
from more_itertools import zip_equal
import regex
from transformers import AutoModelForTokenClassification

from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
from ..search import TaggedSentence, NerTags


# Configuration section for this model; keys are read as class attributes of `NerModel`.
CONFIG = _CONFIG['recognition.NerModel']
# Module-level logger, named after this module (stdlib convention).
LOGGER = getLogger(__name__)


class NerModel:
    """A named-entity-recognition model fine-tuned with the `adaptor` library.

    Training combines a masked-language-modeling (MLM) objective with a
    token-classification (NER) objective on top of a shared language module,
    alternating between the two until the first objective converges.
    """

    # All hyper-parameters come from the `[recognition.NerModel]` config section.
    ROOT_PATH = Path(CONFIG['root_path'])
    BASE_MODEL = CONFIG['base_model']
    BATCH_SIZE = CONFIG.getint('batch_size')
    GRADIENT_ACCUMULATION_STEPS = CONFIG.getint('gradient_accumulation_steps')
    EVAL_STEPS = CONFIG.getint('evaluate_every_n_steps')
    SAVE_STEPS = CONFIG.getint('save_every_n_steps')
    LOGGING_STEPS = CONFIG.getint('log_every_n_steps')
    NUM_TRAIN_EPOCHS = CONFIG.getint('number_of_training_epochs')

    def __init__(self, model: AutoModelForTokenClassification):
        # The wrapped Hugging Face token-classification model.
        self.model = model

    @classmethod
    def train_and_save(cls, model_checkpoint_basename: str, model_basename: str,
                       training_sentence_basename: str, validation_sentence_basename: str,
                       training_tagged_sentence_basename: str,
                       validation_tagged_sentence_basename: str) -> None:
        """Fine-tune a NER model and persist it under `ROOT_PATH / model_basename`.

        Checkpoints produced during training are written under
        `ROOT_PATH / model_checkpoint_basename`. Plain sentences feed the MLM
        objective; tagged sentences feed the NER objective.
        """
        language_module = LangModule(cls.BASE_MODEL)

        # Masked-language-modeling objective over untagged sentences.
        mlm_train = list(Document.load_sentences(training_sentence_basename))
        mlm_validation = list(Document.load_sentences(validation_sentence_basename))

        mlm_objective = MaskedLanguageModeling(language_module,
                                               batch_size=cls.BATCH_SIZE,
                                               texts_or_path=mlm_train,
                                               val_texts_or_path=mlm_validation)

        # NER objective over tagged sentences; tokens consisting solely of
        # punctuation (and other non-word characters) are dropped, together
        # with their tags, and sentences left empty are skipped entirely.
        def load_filtered_ner_dataset(
                tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
            texts: List[Sentence] = []
            labels: List[NerTags] = []
            for tagged_sentence in TaggedSentence.load(tagged_sentence_basename):
                # zip_equal() raises if a sentence and its tags disagree in length.
                kept_pairs = [
                    (token, tag)
                    for token, tag in zip_equal(tagged_sentence.sentence.split(),
                                                tagged_sentence.ner_tags.split())
                    if not regex.fullmatch(r'\W+', token)
                ]
                if not kept_pairs:
                    continue
                kept_tokens, kept_tags = zip(*kept_pairs)
                texts.append(' '.join(kept_tokens))
                labels.append(' '.join(kept_tags))
            return texts, labels

        ner_train_texts, ner_train_labels = load_filtered_ner_dataset(
            training_tagged_sentence_basename)
        ner_validation_texts, ner_validation_labels = load_filtered_ner_dataset(
            validation_tagged_sentence_basename)

        # Mean F-score on the validation split decides convergence.
        evaluators = [MeanFScore(decides_convergence=True)]
        ner_objective = TokenClassification(language_module,
                                            batch_size=cls.BATCH_SIZE,
                                            texts_or_path=ner_train_texts,
                                            labels_or_path=ner_train_labels,
                                            val_texts_or_path=ner_validation_texts,
                                            val_labels_or_path=ner_validation_labels,
                                            val_evaluators=evaluators)

        # Alternate between the two objectives (SequentialSchedule) and stop
        # as soon as the first one converges on its validation data.
        checkpoint_pathname = cls.ROOT_PATH / model_checkpoint_basename
        adaptation_arguments = AdaptationArguments(
            output_dir=str(checkpoint_pathname),
            stopping_strategy=StoppingStrategy.FIRST_OBJECTIVE_CONVERGED,
            do_train=True,
            do_eval=True,
            evaluation_strategy='steps',
            eval_steps=cls.EVAL_STEPS,
            save_strategy='steps',
            save_steps=cls.SAVE_STEPS,
            logging_strategy='steps',
            logging_steps=cls.LOGGING_STEPS,
            gradient_accumulation_steps=cls.GRADIENT_ACCUMULATION_STEPS,
            num_train_epochs=cls.NUM_TRAIN_EPOCHS,
            fp16=True,
            fp16_full_eval=True,
        )

        schedule = SequentialSchedule([mlm_objective, ner_objective], adaptation_arguments)
        adapter = Adapter(language_module, schedule, adaptation_arguments)
        adapter.train()

        # Persist the trained model.
        adapter.save_model(str(cls.ROOT_PATH / model_basename))

    @classmethod
    def load(cls, model_basename: str) -> 'NerModel':
        """Load a previously saved NER model from `ROOT_PATH / model_basename`.

        The adaptor saves each objective's head in its own subdirectory; the
        token-classification head lives under 'TokenClassification'.
        """
        saved_model_pathname = cls.ROOT_PATH / model_basename / 'TokenClassification'
        loaded_model = AutoModelForTokenClassification.from_pretrained(str(saved_model_pathname))
        return cls(loaded_model)
+2 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from .evaluation import (
)

from .result import (
    NerTags,
    SearchResultList,
    TaggedSentence,
)
@@ -21,5 +22,6 @@ __all__ = [
    'Evaluation',
    'Search',
    'SearchResultList',
    'NerTags',
    'TaggedSentence',
]
+2 −0
Original line number Diff line number Diff line
adaptor==0.1.6
bert-score==0.3.11
comet_ml
edit_distance~=1.0.4
gensim~=4.1.2
humanize~=3.2.0
Loading