Commit cd8adc77 authored by Vít Novotný's avatar Vít Novotný
Browse files

Rename CategoryName and Category to Label and LabelId

parent f8df7140
Pipeline #147689 passed with stage
in 10 minutes and 13 seconds
......@@ -8,15 +8,15 @@ import torch
from transformers import PreTrainedTokenizer
from ..config import CONFIG as _CONFIG
from ..search import BioNerTag as CategoryName
from ..search import BioNerTag as Label
CategoryNames = Iterable[CategoryName]
Labels = Iterable[Label]
GroupName = str
GroupNames = Iterable[GroupName]
Category = int
Group = Set[CategoryName]
CategoryMap = Dict[CategoryName, Category]
LabelId = int
Group = Set[Label]
LabelMap = Dict[Label, LabelId]
GroupMap = Dict[GroupName, Group]
FScore = float
......@@ -33,14 +33,14 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
'O': {'O', 'B-MISC', 'I-MISC'},
}
def __init__(self, category_map: CategoryMap, group_names: Optional[GroupNames],
def __init__(self, label_map: LabelMap, group_names: Optional[GroupNames],
*args, **kwargs):
self.category_map = category_map
self.label_map = label_map
self.group_names = self.DEFAULT_GROUP_NAMES if group_names is None else tuple(group_names)
super().__init__(*args, **kwargs)
def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
dataset: AdaptationDataset, ignored_index: Category = -100) -> FScore:
dataset: AdaptationDataset, ignored_index: LabelId = -100) -> FScore:
expected_labels, actual_labels = self._collect_token_predictions(model, dataset)
mean_f_score, total_number_of_samples = 0, 0
......@@ -54,13 +54,13 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
return mean_f_score
def get_f_score(self, group: Group, expected_labels: List[Category],
actual_labels: List[Category], ignored_index: Category) -> Tuple[int, FScore]:
expected_categories: Set[Category] = {
self.category_map[category]
for category
def get_f_score(self, group: Group, expected_labels: List[LabelId],
actual_labels: List[LabelId], ignored_index: LabelId) -> Tuple[int, FScore]:
expected_categories: Set[LabelId] = {
self.label_map[line_id]
for line_id
in group
if category in self.category_map
if line_id in self.label_map
}
true_positives, false_positives, false_negatives = 0, 0, 0
......
......@@ -15,7 +15,7 @@ from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
from ..search import TaggedSentence, BioNerTags
from .schedule import ScheduleName, get_schedule
from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
from .evaluator import AggregateMeanFScoreEvaluator, FScore, LabelMap, Label
from .objective import BIOTokenPunctuationStrippingClassification
......@@ -38,7 +38,7 @@ class NerModel:
SCHEDULE_NAME = CONFIG['schedule']
NUM_VALIDATION_SAMPLES = CONFIG.getint('number_of_validation_samples')
STOPPING_PATIENCE = CONFIG.getint('stopping_patience')
LABELS: Iterable[CategoryName] = ('B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O')
LABELS: Iterable[Label] = ('B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O')
def __init__(self, model_name_or_basename: str, labels: Iterable[str] = LABELS):
self.model_name_or_basename = model_name_or_basename
......@@ -175,11 +175,11 @@ def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], Lis
def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScoreEvaluator]:
category_map: CategoryMap = {
category: category_index
for category_index, category
label_map: LabelMap = {
line_id: line_id_index
for line_id_index, line_id
in enumerate(sorted(labels))
}
for group_name in sorted(AggregateMeanFScoreEvaluator.GROUPS.keys()):
yield AggregateMeanFScoreEvaluator(category_map, [group_name], decides_convergence=False)
yield AggregateMeanFScoreEvaluator(category_map, None, decides_convergence=True)
yield AggregateMeanFScoreEvaluator(label_map, [group_name], decides_convergence=False)
yield AggregateMeanFScoreEvaluator(label_map, None, decides_convergence=True)
......@@ -6,7 +6,7 @@ import regex
import torch
from transformers import DataCollatorForTokenClassification
from .evaluator import CategoryName, CategoryNames, Category
from .evaluator import Label, Labels, LabelId
Token = str
......@@ -14,10 +14,10 @@ Token = str
class BIOTokenPunctuationStrippingClassification(TokenClassification):
def _wordpiece_token_label_alignment(self,
texts: CategoryNames,
labels: CategoryNames,
texts: Labels,
labels: Labels,
label_all_tokens: bool = True,
ignore_label_idx: Category = -100) -> Iterable[Dict[str, torch.LongTensor]]:
ignore_label_idx: LabelId = -100) -> Iterable[Dict[str, torch.LongTensor]]:
texts, labels = list(texts), list(labels)
collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8)
......@@ -58,14 +58,14 @@ class BIOTokenPunctuationStrippingClassification(TokenClassification):
# labels of BoS and EoS are always "other"
out_label_ids = [ignore_label_idx] * len(special_bos_tokens)
def get_label_type(label: CategoryName) -> CategoryName:
def get_label_type(label: Label) -> Label:
if label.startswith('B-') or label.startswith('I-'):
label_type = label[2:]
else:
label_type = label
return label_type
def get_label_ids(head_label: CategoryName) -> Iterable[Category]:
def get_label_ids(head_label: Label) -> Iterable[LabelId]:
head_label_id = self.labels_map[head_label]
tail_label = f'I-{get_label_type(head_label)}' if head_label.startswith('B-') else head_label
tail_label_id = self.labels_map[tail_label]
......@@ -80,8 +80,8 @@ class BIOTokenPunctuationStrippingClassification(TokenClassification):
is_punctuation = punctuation_match is not None
return is_punctuation
def strip_trailing_punctuation(label_ids: Iterable[Category],
tokens: Iterable[Token]) -> Iterable[Category]:
def strip_trailing_punctuation(label_ids: Iterable[LabelId],
tokens: Iterable[Token]) -> Iterable[LabelId]:
label_ids, tokens = list(label_ids), list(tokens)
assert len(label_ids) == len(tokens)
for token_index, token in reversed(list(enumerate(tokens))):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment