Commit 63feb75b authored by Vít Novotný's avatar Vít Novotný
Browse files

Add `ahisto_named_entity_search.recognition.evaluator`

Adds the `AggregateMeanFScore` class, which can be used to evaluate NER
models on different categories and to aggregate the scores for clusters
of categories such as B-PER and I-PER.
parent 171526e6
Pipeline #147217 passed with stage
in 8 minutes and 42 seconds
from .evaluator import (
from .model import (
......@@ -10,6 +14,7 @@ from .schedule import (
__all__ = [
from typing import Dict, Optional, Set, Iterable, List
from more_itertools import zip_equal
import torch
from transformers import PreTrainedTokenizer
from adaptor.evaluators.token_classification import TokenClassificationEvaluator
from adaptor.utils import AdaptationDataset
CategoryName = str
GroupName = str
Category = int
Group = Iterable[CategoryName]
CategoryMap = Dict[CategoryName, Category]
GroupMap = Dict[GroupName, Group]
FScore = float
class AggregateMeanFScore(TokenClassificationEvaluator):
GROUPS: GroupMap = {
'PER': ('B-PER', 'I-PER'),
'LOC': ('B-LOC', 'I-LOC'),
'O': ('O', ),
def __init__(self, category_map: CategoryMap, group_name: Optional[GroupName],
*args, **kwargs):
self.category_map = category_map
self.group_name = group_name
super().__init__(*args, **kwargs)
def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
dataset: AdaptationDataset) -> FScore:
expected_labels, actual_labels = self._collect_token_predictions(model, dataset)
if self.group_name is None:
f_scores = [
self.get_f_score(self.GROUPS[group_name], expected_labels, actual_labels)
for group_name
in self.__class__.get_all_group_names()
assert len(f_scores) > 0
mean_f_score = sum(f_scores) / len(f_scores)
group = self.GROUPS[self.group_name]
mean_f_score = self.get_f_score(group, expected_labels, actual_labels)
return mean_f_score
def get_f_score(self, group: Group, expected_labels: List[Category],
actual_labels: List[Category]) -> FScore:
expected_categories: Set[Category] = {self.category_map[category] for category in group}
true_positives, false_positives, false_negatives = 0, 0, 0
for expected_label, actual_label in zip_equal(expected_labels, actual_labels):
if expected_label in expected_categories and actual_label in expected_categories:
true_positives += 1
elif expected_label not in expected_categories and actual_label in expected_categories:
false_positives += 1
elif expected_label in expected_categories and actual_label not in expected_categories:
false_negatives += 1
f_score = true_positives / (true_positives + (0.5 * (false_positives + false_negatives)))
return f_score
def get_all_group_names(cls) -> Iterable[GroupName]:
for group_name in cls.GROUPS.keys():
yield group_name
def __str__(self) -> str:
if self.group_name is None:
return 'all'
return self.group_name
......@@ -2,12 +2,11 @@ from __future__ import annotations
from logging import getLogger
from pathlib import Path
from typing import Tuple, List, Optional, Iterable
from typing import Tuple, List, Optional, Iterable, Dict
import comet_ml # noqa: F401
from adaptor.adapter import Adapter
from adaptor.objectives.classification import TokenClassification
from adaptor.evaluators.token_classification import MeanFScore
from adaptor.objectives.MLM import MaskedLanguageModeling
from adaptor.lang_module import LangModule
from adaptor.utils import StoppingStrategy, AdaptationArguments
......@@ -17,12 +16,13 @@ from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
from import TaggedSentence, NerTags
from .schedule import ScheduleName, get_schedule
from .evaluator import AggregateMeanFScore, FScore, CategoryMap, CategoryName
LOGGER = getLogger(__name__)
EvaluationResult = float
EvaluationResult = Dict[AggregateMeanFScore, FScore]
class NerModel:
......@@ -38,7 +38,7 @@ class NerModel:
NUM_VALIDATION_SAMPLES = CONFIG.getint('number_of_validation_samples')
STOPPING_PATIENCE = CONFIG.getint('stopping_patience')
LABELS = ['B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O']
LABELS: Iterable[CategoryName] = ('B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O')
def __init__(self, model_name_or_basename: str, labels: Iterable[str] = LABELS):
self.model_name_or_basename = model_name_or_basename
......@@ -64,14 +64,14 @@ class NerModel:
lang_module = LangModule(self.model_name_or_basename)
ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
ner_evaluator = MeanFScore()
ner_evaluators = list(get_evaluators(self.labels))
ner_objective = TokenClassification(lang_module,
texts_or_path=['placeholder text'],
labels_or_path=[' '.join(self.labels)],
adaptation_arguments = AdaptationArguments(
......@@ -81,8 +81,12 @@ class NerModel:
adapter = Adapter(lang_module, schedule, adaptation_arguments) # noqa: F841
test_result = ner_objective.per_objective_log("eval")
test_f_score = test_result[f'eval_{ner_objective}_{ner_evaluator}']
return test_f_score
evaluation_results = {
ner_evaluator: test_result[f'eval_{ner_objective}_{ner_evaluator}']
for ner_evaluator
in ner_evaluators
return evaluation_results
def train_and_save(cls, model_checkpoint_basename: str, model_basename: str,
......@@ -111,7 +115,7 @@ class NerModel:
ner_validation_texts = ner_validation_texts[:cls.NUM_VALIDATION_SAMPLES]
ner_validation_labels = ner_validation_labels[:cls.NUM_VALIDATION_SAMPLES]
ner_evaluators = [MeanFScore(decides_convergence=True)]
ner_evaluators = list(get_evaluators(cls.LABELS))
ner_objective = TokenClassification(lang_module,
......@@ -166,3 +170,14 @@ def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], Lis
return ner_texts, all_ner_tags
def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScore]:
category_map: CategoryMap = {
category: category_index
for category_index, category
in enumerate(sorted(labels))
for group_name in AggregateMeanFScore.get_all_group_names():
yield AggregateMeanFScore(category_map, group_name, decides_convergence=False)
yield AggregateMeanFScore(category_map, None, decides_convergence=True)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment