Commit 0b717e6d authored by Vít Starý Novotný

Add `BIOTokenPunctuationStrippingClassification`
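
`BIOTokenPunctuationStrippingClassification` subclasses adaptor's
`TokenClassification` and changes the wordpiece-to-label alignment: trailing
punctuation at the end of an entity span is relabeled as 'O', so that
punctuation attached to an entity is not trained as part of the entity.
`AggregateMeanFScoreEvaluator` now skips positions that carry the ignored
label index (default -100), and `NerModel` uses the new objective for
training, validation, and loading.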

parent 224e461d
+5 −0
@@ -6,6 +6,10 @@ from .model import (
     NerModel
 )
 
+from .objective import (
+    BIOTokenPunctuationStrippingClassification,
+)
+
 from .schedule import (
     get_schedule,
     ScheduleName,
@@ -15,6 +19,7 @@ from .schedule import (
 
 __all__ = [
     'AggregateMeanFScoreEvaluator',
+    'BIOTokenPunctuationStrippingClassification',
     'get_schedule',
     'NerModel',
     'ScheduleName',
+6 −3
@@ -11,6 +11,7 @@ from ..config import CONFIG as _CONFIG
 
 
 CategoryName = str
+CategoryNames = Iterable[CategoryName]
 GroupName = str
 GroupNames = Iterable[GroupName]
 Category = int
@@ -39,13 +40,13 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         super().__init__(*args, **kwargs)
 
     def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
-                 dataset: AdaptationDataset) -> FScore:
+                 dataset: AdaptationDataset, ignored_index: Category = -100) -> FScore:
         expected_labels, actual_labels = self._collect_token_predictions(model, dataset)
 
         mean_f_score, total_number_of_samples = 0, 0
         for group_name in self.group_names:
             number_of_samples, f_score = self.get_f_score(
-                self.GROUPS[group_name], expected_labels, actual_labels)
+                self.GROUPS[group_name], expected_labels, actual_labels, ignored_index)
             mean_f_score += number_of_samples * f_score
             total_number_of_samples += number_of_samples
         if total_number_of_samples > 0:
@@ -54,7 +55,7 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         return mean_f_score
 
     def get_f_score(self, group: Group, expected_labels: List[Category],
-                    actual_labels: List[Category]) -> Tuple[int, FScore]:
+                    actual_labels: List[Category], ignored_index: Category) -> Tuple[int, FScore]:
         expected_categories: Set[Category] = {
             self.category_map[category]
             for category
@@ -64,6 +65,8 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
 
         true_positives, false_positives, false_negatives = 0, 0, 0
         for expected_label, actual_label in zip_equal(expected_labels, actual_labels):
+            if expected_label == ignored_index:
+                continue
             if expected_label in expected_categories and actual_label in expected_categories:
                 true_positives += 1
             elif expected_label not in expected_categories and actual_label in expected_categories:
+18 −16
@@ -6,7 +6,6 @@ from typing import Tuple, List, Optional, Iterable, Dict
 
 import comet_ml  # noqa: F401
 from adaptor.adapter import Adapter
-from adaptor.objectives.classification import TokenClassification
 from adaptor.objectives.MLM import MaskedLanguageModeling
 from adaptor.lang_module import LangModule
 from adaptor.utils import StoppingStrategy, AdaptationArguments
@@ -17,6 +16,7 @@ from ..document import Document, Sentence
 from ..search import TaggedSentence, NerTags
 from .schedule import ScheduleName, get_schedule
 from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
+from .objective import BIOTokenPunctuationStrippingClassification
 
 
 LOGGER = getLogger(__name__)
@@ -65,7 +65,8 @@ class NerModel:
 
         ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
         ner_evaluators = list(get_evaluators(self.labels))
-        ner_objective = TokenClassification(lang_module,
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module,
             batch_size=1,
             texts_or_path=['placeholder text'],
             labels_or_path=[' '.join(self.labels)],
@@ -116,7 +117,8 @@ class NerModel:
         ner_validation_labels = ner_validation_labels[:cls.NUM_VALIDATION_SAMPLES]
 
         ner_evaluators = list(get_evaluators(cls.LABELS))
-        ner_objective = TokenClassification(lang_module,
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module,
             batch_size=cls.BATCH_SIZE,
             texts_or_path=ner_training_texts,
             labels_or_path=ner_training_labels,
@@ -159,7 +161,7 @@ class NerModel:
     @classmethod
     def load(cls, model_basename: str) -> 'NerModel':
         model_pathname = cls.ROOT_PATH / model_basename
-        model_pathname = model_pathname / 'TokenClassification'
+        model_pathname = model_pathname / 'BIOTokenPunctuationStrippingClassification'
         model_name_or_basename = str(model_pathname)
         return cls(model_name_or_basename)
 
+145 −0
from itertools import islice
from typing import Dict, Iterable

from adaptor.objectives.classification import TokenClassification
import regex
import torch
from transformers import DataCollatorForTokenClassification

from .evaluator import CategoryName, CategoryNames, Category


Token = str


class BIOTokenPunctuationStrippingClassification(TokenClassification):
    def _wordpiece_token_label_alignment(self,
                                         texts: CategoryNames,
                                         labels: CategoryNames,
                                         label_all_tokens: bool = True,
                                         ignore_label_idx: Category = -100) -> Iterable[Dict[str, torch.LongTensor]]:
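        """Align whitespace-separated BIO labels with wordpieces and yield collated batches.

        Unlike the parent TokenClassification objective, trailing punctuation
        at the end of an entity span is relabeled as 'O', so that punctuation
        attached to an entity is not trained as part of the entity.
        """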
        texts, labels = list(texts), list(labels)

        collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8)
        batch_features = []

        # special tokens identification: general heuristic
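        # tokenize two different texts; token ids shared at the start (end) of
        # both sequences can only come from the tokenizer's special tokens,
        # not from the text itself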
        ids1 = self.tokenizer("X").input_ids
        ids2 = self.tokenizer("Y").input_ids

        special_bos_tokens = []
        for i in range(len(ids1)):
            if ids1[i] == ids2[i]:
                special_bos_tokens.append(ids1[i])
            else:
                break

        special_eos_tokens = []
        for i in range(1, len(ids1)):
            if ids1[-i] == ids2[-i]:
                special_eos_tokens.append(ids1[-i])
            else:
                break
        special_eos_tokens = list(reversed(special_eos_tokens))

        # per-sample iteration
        for text, text_labels in zip(texts, labels):
            tokens = text.split()
            token_labels = text_labels.split()

            assert len(tokens) == len(token_labels), \
                "The number of tokens differs from the number of labels. " \
                "Text: %s \nLabels: %s" % (text, text_labels)

            tokens_ids = self.tokenizer(tokens, truncation=True, add_special_tokens=False).input_ids

            wpiece_ids = special_bos_tokens.copy()

            # predictions over the BoS and EoS special tokens are excluded from the loss
            out_label_ids = [ignore_label_idx] * len(special_bos_tokens)

            def get_label_type(label: CategoryName) -> CategoryName:
                if label.startswith('B-') or label.startswith('I-'):
                    label_type = label[2:]
                else:
                    label_type = label
                return label_type

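            # infinite label generator: the first wordpiece keeps the token's
            # label; every following wordpiece gets the matching I- label (for
            # a B- label) or repeats the label itself (for I- and O labels)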
            def get_label_ids(head_label: CategoryName) -> Iterable[Category]:
                head_label_id = self.labels_map[head_label]
                tail_label = f'I-{get_label_type(head_label)}' if head_label.startswith('B-') else head_label
                tail_label_id = self.labels_map[tail_label]

                yield head_label_id
                while True:
                    yield tail_label_id

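            # a token counts as punctuation when it contains no word
            # characters at all (note that \W* also matches the empty string)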
            def is_punctuation(token: Token) -> bool:
                punctuation_regex = r'^\W*$'
                punctuation_match = regex.match(punctuation_regex, token)
                is_punctuation = punctuation_match is not None
                return is_punctuation

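            # walk the token's wordpieces backwards and relabel trailing
            # punctuation as 'O', stopping at the first non-punctuation piece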
            def strip_trailing_punctuation(label_ids: Iterable[Category],
                                           tokens: Iterable[Token]) -> Iterable[Category]:
                label_ids, tokens = list(label_ids), list(tokens)
                assert len(label_ids) == len(tokens)
                for token_index, token in reversed(list(enumerate(tokens))):
                    if is_punctuation(token):
                        label_ids[token_index] = self.labels_map['O']
                    else:
                        break
                return label_ids

            for label_index, (token_ids, label) in enumerate(zip(tokens_ids, token_labels)):
                # chain the wordpieces without the special symbols for each token
                wpiece_ids.extend(token_ids)
                if label_all_tokens:
                    # label all wordpieces
                    label_ids = get_label_ids(label)
                    label_ids = islice(label_ids, len(token_ids))
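                    # strip trailing punctuation only where this token closes
                    # an entity span, i.e. the next token carries a different
                    # label type; the final token of a sample is left as-is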
                    if label != 'O':
                        if label_index + 1 >= len(token_labels):
                            should_strip_punctuation = False
                        else:
                            label_type = get_label_type(label)
                            next_label = token_labels[label_index + 1]
                            next_label_type = get_label_type(next_label)
                            should_strip_punctuation = label_type != next_label_type
                        if should_strip_punctuation:
                            wordpieces = self.tokenizer.batch_decode(token_ids)
                            label_ids = strip_trailing_punctuation(label_ids, wordpieces)
                    out_label_ids.extend(label_ids)
                else:
                    # label only the first wordpiece
                    out_label_ids.append(self.labels_map[label])
                    # ignore the predictions over other token's wordpieces from the loss
                    out_label_ids.extend([ignore_label_idx] * (len(token_ids) - 1))

            out_label_ids.extend([ignore_label_idx] * len(special_eos_tokens))
            wpiece_ids.extend(special_eos_tokens)

            assert len(out_label_ids) == len(wpiece_ids), \
                "We found misaligned labels in sample: '%s'" % text

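            # truncate inputs and labels to the model's maximum input length,
            # if the tokenizer defines one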
            if self.tokenizer.model_max_length is None:
                truncated_size = len(out_label_ids)
            else:
                truncated_size = min(self.tokenizer.model_max_length, len(out_label_ids))

            batch_features.append({"input_ids": wpiece_ids[:truncated_size],
                                   "attention_mask": [1] * truncated_size,
                                   "labels": out_label_ids[:truncated_size]})
            # maybe yield a batch
            if len(batch_features) == self.batch_size:
                yield collator(batch_features)
                batch_features = []
        if batch_features:
            yield collator(batch_features)

        # check that the number of outputs of the selected compatible head matches the just-parsed data set
        num_outputs = list(self.compatible_head_model.parameters())[-1].shape[0]
        num_labels = len(self.labels_map)
        assert num_outputs == num_labels, "The number of outputs of the selected %s head (%s) " \
                                          "does not match the number of token labels (%s)" \
                                          % (self.compatible_head, num_outputs, num_labels)
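
For illustration, here is a self-contained sketch of the punctuation-stripping
rule in isolation (a hypothetical standalone rewrite of the two helpers above;
`labels_map` stands in for the objective's label vocabulary):

    import regex

    labels_map = {'O': 0, 'B-PER': 1, 'I-PER': 2}  # stand-in label vocabulary

    def is_punctuation(token: str) -> bool:
        # a token is punctuation when it contains no word characters at all
        return regex.match(r'^\W*$', token) is not None

    def strip_trailing_punctuation(label_ids, tokens):
        label_ids, tokens = list(label_ids), list(tokens)
        assert len(label_ids) == len(tokens)
        for token_index, token in reversed(list(enumerate(tokens))):
            if is_punctuation(token):
                label_ids[token_index] = labels_map['O']
            else:
                break
        return label_ids

    # the trailing comma after a person entity is relabeled as 'O':
    print(strip_trailing_punctuation([2, 2, 2], ['Novot', 'ný', ',']))
    # [2, 2, 0]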