Commit 0b717e6d authored by Vít Novotný

Add `BIOTokenPunctuationStrippingClassification`
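Replace the generic `TokenClassification` NER objective with the new
`BIOTokenPunctuationStrippingClassification`, which aligns BIO labels with
wordpieces and relabels the trailing punctuation of entity spans as 'O'.
Also let `AggregateMeanFScoreEvaluator` skip the ignored label index (-100)
when computing F-scores, and load checkpoints from the new objective's
directory.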

parent 224e461d
@@ -6,6 +6,10 @@ from .model import (
     NerModel
 )
+from .objective import (
+    BIOTokenPunctuationStrippingClassification,
+)
 from .schedule import (
     get_schedule,
     ScheduleName,
@@ -15,6 +19,7 @@ from .schedule import (
 __all__ = [
     'AggregateMeanFScoreEvaluator',
+    'BIOTokenPunctuationStrippingClassification',
     'get_schedule',
     'NerModel',
     'ScheduleName',
......
@@ -11,6 +11,7 @@ from ..config import CONFIG as _CONFIG
 CategoryName = str
+CategoryNames = Iterable[CategoryName]
 GroupName = str
 GroupNames = Iterable[GroupName]
 Category = int
@@ -39,13 +40,13 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         super().__init__(*args, **kwargs)

     def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
-                 dataset: AdaptationDataset) -> FScore:
+                 dataset: AdaptationDataset, ignored_index: Category = -100) -> FScore:
         expected_labels, actual_labels = self._collect_token_predictions(model, dataset)
         mean_f_score, total_number_of_samples = 0, 0
         for group_name in self.group_names:
             number_of_samples, f_score = self.get_f_score(
-                self.GROUPS[group_name], expected_labels, actual_labels)
+                self.GROUPS[group_name], expected_labels, actual_labels, ignored_index)
             mean_f_score += number_of_samples * f_score
             total_number_of_samples += number_of_samples
         if total_number_of_samples > 0:
@@ -54,7 +55,7 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         return mean_f_score

     def get_f_score(self, group: Group, expected_labels: List[Category],
-                    actual_labels: List[Category]) -> Tuple[int, FScore]:
+                    actual_labels: List[Category], ignored_index: Category) -> Tuple[int, FScore]:
         expected_categories: Set[Category] = {
             self.category_map[category]
             for category
@@ -64,6 +65,8 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         true_positives, false_positives, false_negatives = 0, 0, 0
         for expected_label, actual_label in zip_equal(expected_labels, actual_labels):
+            if expected_label == ignored_index:
+                continue
             if expected_label in expected_categories and actual_label in expected_categories:
                 true_positives += 1
             elif expected_label not in expected_categories and actual_label in expected_categories:
......
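Not part of the commit: a minimal, self-contained sketch of the counting
logic behind the evaluator change above. The function name group_f_score and
the positive set of category ids are ours, chosen for illustration; positions
labelled with ignored_index (non-first wordpieces of a token, BoS/EoS) are
excluded from the confusion counts entirely.

    from typing import List, Set

    def group_f_score(expected: List[int], actual: List[int],
                      positive: Set[int], ignored_index: int = -100) -> float:
        # count true/false positives and false negatives for one label group,
        # skipping positions that carry no gold label
        tp = fp = fn = 0
        for e, a in zip(expected, actual):
            if e == ignored_index:
                continue
            if e in positive and a in positive:
                tp += 1
            elif e not in positive and a in positive:
                fp += 1
            elif e in positive and a not in positive:
                fn += 1
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        return 2 * precision * recall / (precision + recall) if precision + recall else 0.0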
@@ -6,7 +6,6 @@ from typing import Tuple, List, Optional, Iterable, Dict
 import comet_ml  # noqa: F401
 from adaptor.adapter import Adapter
-from adaptor.objectives.classification import TokenClassification
 from adaptor.objectives.MLM import MaskedLanguageModeling
 from adaptor.lang_module import LangModule
 from adaptor.utils import StoppingStrategy, AdaptationArguments
@@ -17,6 +16,7 @@ from ..document import Document, Sentence
 from ..search import TaggedSentence, NerTags
 from .schedule import ScheduleName, get_schedule
 from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
+from .objective import BIOTokenPunctuationStrippingClassification

 LOGGER = getLogger(__name__)
@@ -65,13 +65,14 @@ class NerModel:
         ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
         ner_evaluators = list(get_evaluators(self.labels))
-        ner_objective = TokenClassification(lang_module,
-                                            batch_size=1,
-                                            texts_or_path=['placeholder text'],
-                                            labels_or_path=[' '.join(self.labels)],
-                                            val_texts_or_path=ner_test_texts,
-                                            val_labels_or_path=ner_test_labels,
-                                            val_evaluators=ner_evaluators)
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module,
+            batch_size=1,
+            texts_or_path=['placeholder text'],
+            labels_or_path=[' '.join(self.labels)],
+            val_texts_or_path=ner_test_texts,
+            val_labels_or_path=ner_test_labels,
+            val_evaluators=ner_evaluators)
         adaptation_arguments = AdaptationArguments(
             output_dir='.',
             stopping_strategy=StoppingStrategy.FIRST_OBJECTIVE_CONVERGED,
@@ -116,13 +117,14 @@ class NerModel:
         ner_validation_labels = ner_validation_labels[:cls.NUM_VALIDATION_SAMPLES]
         ner_evaluators = list(get_evaluators(cls.LABELS))
-        ner_objective = TokenClassification(lang_module,
-                                            batch_size=cls.BATCH_SIZE,
-                                            texts_or_path=ner_training_texts,
-                                            labels_or_path=ner_training_labels,
-                                            val_texts_or_path=ner_validation_texts,
-                                            val_labels_or_path=ner_validation_labels,
-                                            val_evaluators=ner_evaluators)
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module,
+            batch_size=cls.BATCH_SIZE,
+            texts_or_path=ner_training_texts,
+            labels_or_path=ner_training_labels,
+            val_texts_or_path=ner_validation_texts,
+            val_labels_or_path=ner_validation_labels,
+            val_evaluators=ner_evaluators)
         # Train MLM and NER in parallel until convergence on validation
         model_checkpoint_pathname = cls.ROOT_PATH / model_checkpoint_basename
@@ -159,7 +161,7 @@ class NerModel:
     @classmethod
     def load(cls, model_basename: str) -> 'NerModel':
         model_pathname = cls.ROOT_PATH / model_basename
-        model_pathname = model_pathname / 'TokenClassification'
+        model_pathname = model_pathname / 'BIOTokenPunctuationStrippingClassification'
         model_name_or_basename = str(model_pathname)
         return cls(model_name_or_basename)
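Checkpoints written by the adaptor Adapter appear to be namespaced by
objective class name, which is why load() gains the new path segment above.
A usage sketch; the basename 'my_model' is hypothetical:

    # after training, the checkpoint is expected to sit under
    #   <ROOT_PATH>/my_model/BIOTokenPunctuationStrippingClassification/
    model = NerModel.load('my_model')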
......
from itertools import islice
from typing import Dict, Iterable

from adaptor.objectives.classification import TokenClassification
import regex
import torch
from transformers import DataCollatorForTokenClassification

from .evaluator import CategoryName, CategoryNames, Category

Token = str
class BIOTokenPunctuationStrippingClassification(TokenClassification):
    def _wordpiece_token_label_alignment(self,
                                         texts: CategoryNames,
                                         labels: CategoryNames,
                                         label_all_tokens: bool = True,
                                         ignore_label_idx: Category = -100) -> Iterable[Dict[str, torch.LongTensor]]:
        texts, labels = list(texts), list(labels)
        collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8)
        batch_features = []

        # special tokens identification: general heuristic. Encoding two
        # different one-word texts, the ids shared at the start and at the end
        # of both encodings must be the special BoS/EoS tokens.
        ids1 = self.tokenizer("X").input_ids
        ids2 = self.tokenizer("Y").input_ids
        special_bos_tokens = []
        for i in range(len(ids1)):
            if ids1[i] == ids2[i]:
                special_bos_tokens.append(ids1[i])
            else:
                break
        special_eos_tokens = []
        for i in range(1, len(ids1)):
            if ids1[-i] == ids2[-i]:
                special_eos_tokens.append(ids1[-i])
            else:
                break
        special_eos_tokens = list(reversed(special_eos_tokens))
        # per-sample iteration
        for text, text_labels in zip(texts, labels):
            tokens = text.split()
            token_labels = text_labels.split()
            assert len(tokens) == len(token_labels), \
                "The number of tokens differs from the number of labels. " \
                "Text: %s \nLabels: %s" % (text, text_labels)
            tokens_ids = self.tokenizer(tokens, truncation=True, add_special_tokens=False).input_ids
            wpiece_ids = special_bos_tokens.copy()
            # labels of BoS and EoS are always "other"
            out_label_ids = [ignore_label_idx] * len(special_bos_tokens)
            def get_label_type(label: CategoryName) -> CategoryName:
                # 'B-PER' and 'I-PER' both have the label type 'PER'
                if label.startswith('B-') or label.startswith('I-'):
                    label_type = label[2:]
                else:
                    label_type = label
                return label_type
            def get_label_ids(head_label: CategoryName) -> Iterable[Category]:
                # infinite generator: yields the id of head_label once, then the
                # label id of any continuation wordpieces ('I-<type>' for a
                # 'B-<type>' head, otherwise the head label itself); callers
                # slice it with islice
                head_label_id = self.labels_map[head_label]
                tail_label = f'I-{get_label_type(head_label)}' if head_label.startswith('B-') else head_label
                tail_label_id = self.labels_map[tail_label]
                yield head_label_id
                while True:
                    yield tail_label_id
            def is_punctuation(token: Token) -> bool:
                punctuation_regex = r'^\W*$'
                punctuation_match = regex.match(punctuation_regex, token)
                is_punctuation = punctuation_match is not None
                return is_punctuation
            def strip_trailing_punctuation(label_ids: Iterable[Category],
                                           tokens: Iterable[Token]) -> Iterable[Category]:
                label_ids, tokens = list(label_ids), list(tokens)
                assert len(label_ids) == len(tokens)
                for token_index, token in reversed(list(enumerate(tokens))):
                    if is_punctuation(token):
                        label_ids[token_index] = self.labels_map['O']
                    else:
                        break
                return label_ids
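            # for example, an entity-final token 'Obama,' tokenized into the
            # wordpieces ['Obama', ','] has its trailing ',' wordpiece
            # relabelled from I-PER to 'O' by strip_trailing_punctuation()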
            for label_index, (token_ids, label) in enumerate(zip(tokens_ids, token_labels)):
                # chain the wordpieces without the special symbols for each token
                wpiece_ids.extend(token_ids)
                if label_all_tokens:
                    # label all wordpieces
                    label_ids = get_label_ids(label)
                    label_ids = islice(label_ids, len(token_ids))
                    if label != 'O':
                        # strip trailing punctuation at the end of an entity
                        # span, i.e. when the next token has a different label
                        # type; the last token of a sample is never stripped
                        if label_index + 1 >= len(token_labels):
                            should_strip_punctuation = False
                        else:
                            label_type = get_label_type(label)
                            next_label = token_labels[label_index + 1]
                            next_label_type = get_label_type(next_label)
                            should_strip_punctuation = label_type != next_label_type
                        if should_strip_punctuation:
                            wordpieces = self.tokenizer.batch_decode(token_ids)
                            label_ids = strip_trailing_punctuation(label_ids, wordpieces)
                    out_label_ids.extend(label_ids)
                else:
                    # label only the first wordpiece
                    out_label_ids.append(self.labels_map[label])
                    # ignore predictions over the token's other wordpieces in the loss
                    out_label_ids.extend([ignore_label_idx] * (len(token_ids) - 1))
            out_label_ids.extend([ignore_label_idx] * len(special_eos_tokens))
            wpiece_ids.extend(special_eos_tokens.copy())
            assert len(out_label_ids) == len(wpiece_ids), \
                "We found misaligned labels in sample: '%s'" % text
            if self.tokenizer.model_max_length is None:
                truncated_size = len(out_label_ids)
            else:
                truncated_size = min(self.tokenizer.model_max_length, len(out_label_ids))
            batch_features.append({"input_ids": wpiece_ids[:truncated_size],
                                   "attention_mask": [1] * truncated_size,
                                   "labels": out_label_ids[:truncated_size]})
            # maybe yield a batch
            if len(batch_features) == self.batch_size:
                yield collator(batch_features)
                batch_features = []
        if batch_features:
            yield collator(batch_features)
        # check that the number of outputs of the selected compatible head matches the just-parsed dataset
        num_outputs = list(self.compatible_head_model.parameters())[-1].shape[0]
        num_labels = len(self.labels_map)
        assert num_outputs == num_labels, "The number of outputs of the selected %s head (%s) " \
                                          "does not match the number of token labels (%s)" \
                                          % (self.compatible_head, num_outputs, num_labels)
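The special-token heuristic at the top of _wordpiece_token_label_alignment
can be checked in isolation. A minimal sketch, assuming an 'xlm-roberta-base'
tokenizer (any Hugging Face tokenizer should behave the same way; the ids in
the comments are illustrative, not guaranteed):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
    ids1 = tokenizer('X').input_ids  # e.g. [0, 1061, 2]: <s>, '▁X', </s>
    ids2 = tokenizer('Y').input_ids  # e.g. [0, 990, 2]: <s>, '▁Y', </s>
    # the ids shared at the start ([0]) are the BoS tokens and the ids shared
    # at the end ([2]) are the EoS tokens; everything in between belongs to
    # the encoded text itself

The comparison works because a tokenizer inserts the same special-token
layout around every input, so only the text-dependent ids differ between the
two encodings.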