Commit ee1164d0 authored by Vít Novotný
Browse files

Use micro-averaged `PER` and `LOC` F1-score as default objective

parent 51b285b6
Pipeline #147350 passed with stage
in 8 minutes and 59 seconds
......@@ -237,7 +237,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Loading documents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 268669/268669 [00:06<00:00, 44308.95it/s]\n"
"Loading documents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 268669/268669 [00:05<00:00, 47033.84it/s]\n"
]
}
],
......@@ -462,7 +462,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"id": "694daad3-2b04-4e3f-8bfb-bb3fe0c87dd3",
"metadata": {},
"outputs": [],
......@@ -473,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"id": "fed4d0a4-5bc4-4af2-8e1b-c5a8a6b61c52",
"metadata": {},
"outputs": [],
......@@ -491,7 +491,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"id": "38efa732-8afd-4798-809a-ca828a8b960c",
"metadata": {},
"outputs": [],
......@@ -501,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"id": "1c9d3bab-53de-4c39-8e4d-cd2978f00925",
"metadata": {},
"outputs": [],
......@@ -513,7 +513,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"id": "f47f62ca-1164-45b5-94b9-ff082787e8a9",
"metadata": {},
"outputs": [],
......@@ -526,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"id": "053ff5dd-775c-431d-a988-18bf3c4f4f6d",
"metadata": {},
"outputs": [
......@@ -551,165 +551,165 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PER</th>\n",
" <th>O</th>\n",
" <th>LOC</th>\n",
" <th>all</th>\n",
" <th>O</th>\n",
" <th>PER</th>\n",
" <th>PER+LOC</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>model_ner_manatee_all_only-relevant_fine-tuning</th>\n",
" <td>49.63405%</td>\n",
" <td>96.74517%</td>\n",
" <td>39.86766%</td>\n",
" <td>94.77053%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_parallel</th>\n",
" <td>43.69512%</td>\n",
" <td>96.65764%</td>\n",
" <td>39.60441%</td>\n",
" <td>94.52101%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_fine-tuning</th>\n",
" <td>49.97944%</td>\n",
" <td>96.66777%</td>\n",
" <td>41.04577%</td>\n",
" <td>94.46477%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_fine-tuning</th>\n",
" <td>34.07244%</td>\n",
" <td>96.38580%</td>\n",
" <td>36.68019%</td>\n",
" <td>93.57176%</td>\n",
" <td>96.66777%</td>\n",
" <td>49.97944%</td>\n",
" <td>44.30760%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_non-crossing_only-relevant_fine-tuning</th>\n",
" <td>34.74969%</td>\n",
" <td>96.34174%</td>\n",
" <td>36.64832%</td>\n",
" <td>93.41004%</td>\n",
" <th>model_ner_manatee_all_only-relevant_fine-tuning</th>\n",
" <td>39.86766%</td>\n",
" <td>96.74517%</td>\n",
" <td>49.63405%</td>\n",
" <td>43.94366%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_only-relevant_parallel</th>\n",
" <td>51.62152%</td>\n",
" <td>96.17568%</td>\n",
" <td>38.21471%</td>\n",
" <td>93.32879%</td>\n",
" <td>96.17568%</td>\n",
" <td>51.62152%</td>\n",
" <td>43.57605%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_parallel</th>\n",
" <td>38.21755%</td>\n",
" <td>96.10254%</td>\n",
" <td>49.71262%</td>\n",
" <td>42.94907%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_parallel</th>\n",
" <td>39.60441%</td>\n",
" <td>96.65764%</td>\n",
" <td>43.69512%</td>\n",
" <td>41.58149%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_parallel</th>\n",
" <td>44.26929%</td>\n",
" <td>96.20140%</td>\n",
" <td>34.92015%</td>\n",
" <td>93.31223%</td>\n",
" <td>96.20140%</td>\n",
" <td>44.26929%</td>\n",
" <td>38.61073%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_parallel</th>\n",
" <td>49.71262%</td>\n",
" <td>96.10254%</td>\n",
" <td>38.21755%</td>\n",
" <td>93.23893%</td>\n",
" <th>model_ner_fuzzy-regex_non-crossing_only-relevant_fine-tuning</th>\n",
" <td>36.64832%</td>\n",
" <td>96.34174%</td>\n",
" <td>34.74969%</td>\n",
" <td>35.69222%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_fine-tuning</th>\n",
" <td>36.68019%</td>\n",
" <td>96.38580%</td>\n",
" <td>34.07244%</td>\n",
" <td>35.37521%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_all_parallel</th>\n",
" <td>33.04513%</td>\n",
" <td>95.95721%</td>\n",
" <td>37.66807%</td>\n",
" <td>92.63503%</td>\n",
" <td>95.95721%</td>\n",
" <td>33.04513%</td>\n",
" <td>34.98948%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_non-crossing_all_parallel</th>\n",
" <td>33.17825%</td>\n",
" <td>95.71407%</td>\n",
" <td>32.79457%</td>\n",
" <td>92.03093%</td>\n",
" <td>95.71407%</td>\n",
" <td>33.17825%</td>\n",
" <td>32.99174%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_parallel</th>\n",
" <td>31.30961%</td>\n",
" <td>95.59693%</td>\n",
" <td>32.23955%</td>\n",
" <td>91.78359%</td>\n",
" <td>95.59693%</td>\n",
" <td>31.30961%</td>\n",
" <td>31.77318%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_non-crossing_only-relevant_parallel</th>\n",
" <td>32.51893%</td>\n",
" <td>95.45893%</td>\n",
" <td>30.89947%</td>\n",
" <td>91.47703%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_fine-tuning</th>\n",
" <td>2.17360%</td>\n",
" <td>42.86388%</td>\n",
" <td>3.80143%</td>\n",
" <td>25.11350%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_fine-tuning</th>\n",
" <td>2.34774%</td>\n",
" <td>23.59883%</td>\n",
" <td>2.75918%</td>\n",
" <td>13.38802%</td>\n",
" <td>95.45893%</td>\n",
" <td>32.51893%</td>\n",
" <td>31.66439%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Babelscape/wikineural-multilingual-ner baseline</th>\n",
" <td>7.35338%</td>\n",
" <td>13.35824%</td>\n",
" <td>2.84895%</td>\n",
" <td>8.07667%</td>\n",
" <td>13.35824%</td>\n",
" <td>7.35338%</td>\n",
" <td>3.57886%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_all_fine-tuning</th>\n",
" <td>2.38798%</td>\n",
" <td>7.33972%</td>\n",
" <td>3.32850%</td>\n",
" <td>4.96872%</td>\n",
" <td>7.33972%</td>\n",
" <td>2.38798%</td>\n",
" <td>2.68821%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_fine-tuning</th>\n",
" <td>3.80143%</td>\n",
" <td>42.86388%</td>\n",
" <td>2.17360%</td>\n",
" <td>2.60517%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_fine-tuning</th>\n",
" <td>2.75918%</td>\n",
" <td>23.59883%</td>\n",
" <td>2.34774%</td>\n",
" <td>2.51480%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PER O \\\n",
"model_ner_manatee_all_only-relevant_fine-tuning 49.63405% 96.74517% \n",
"model_ner_manatee_all_all_parallel 43.69512% 96.65764% \n",
"model_ner_manatee_non-crossing_only-relevant_fi... 49.97944% 96.66777% \n",
"model_ner_fuzzy-regex_all_only-relevant_fine-tu... 34.07244% 96.38580% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 34.74969% 96.34174% \n",
"model_ner_manatee_all_only-relevant_parallel 51.62152% 96.17568% \n",
"model_ner_manatee_non-crossing_all_parallel 44.26929% 96.20140% \n",
"model_ner_manatee_non-crossing_only-relevant_pa... 49.71262% 96.10254% \n",
"model_ner_fuzzy-regex_all_all_parallel 33.04513% 95.95721% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 33.17825% 95.71407% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 31.30961% 95.59693% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 32.51893% 95.45893% \n",
"model_ner_manatee_all_all_fine-tuning 2.17360% 42.86388% \n",
"model_ner_manatee_non-crossing_all_fine-tuning 2.34774% 23.59883% \n",
"Babelscape/wikineural-multilingual-ner baseline 7.35338% 13.35824% \n",
"model_ner_fuzzy-regex_all_all_fine-tuning 2.38798% 7.33972% \n",
" LOC O \\\n",
"model_ner_manatee_non-crossing_only-relevant_fi... 41.04577% 96.66777% \n",
"model_ner_manatee_all_only-relevant_fine-tuning 39.86766% 96.74517% \n",
"model_ner_manatee_all_only-relevant_parallel 38.21471% 96.17568% \n",
"model_ner_manatee_non-crossing_only-relevant_pa... 38.21755% 96.10254% \n",
"model_ner_manatee_all_all_parallel 39.60441% 96.65764% \n",
"model_ner_manatee_non-crossing_all_parallel 34.92015% 96.20140% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 36.64832% 96.34174% \n",
"model_ner_fuzzy-regex_all_only-relevant_fine-tu... 36.68019% 96.38580% \n",
"model_ner_fuzzy-regex_all_all_parallel 37.66807% 95.95721% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 32.79457% 95.71407% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 32.23955% 95.59693% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 30.89947% 95.45893% \n",
"Babelscape/wikineural-multilingual-ner baseline 2.84895% 13.35824% \n",
"model_ner_fuzzy-regex_all_all_fine-tuning 3.32850% 7.33972% \n",
"model_ner_manatee_all_all_fine-tuning 3.80143% 42.86388% \n",
"model_ner_manatee_non-crossing_all_fine-tuning 2.75918% 23.59883% \n",
"\n",
" LOC all \n",
"model_ner_manatee_all_only-relevant_fine-tuning 39.86766% 94.77053% \n",
"model_ner_manatee_all_all_parallel 39.60441% 94.52101% \n",
"model_ner_manatee_non-crossing_only-relevant_fi... 41.04577% 94.46477% \n",
"model_ner_fuzzy-regex_all_only-relevant_fine-tu... 36.68019% 93.57176% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 36.64832% 93.41004% \n",
"model_ner_manatee_all_only-relevant_parallel 38.21471% 93.32879% \n",
"model_ner_manatee_non-crossing_all_parallel 34.92015% 93.31223% \n",
"model_ner_manatee_non-crossing_only-relevant_pa... 38.21755% 93.23893% \n",
"model_ner_fuzzy-regex_all_all_parallel 37.66807% 92.63503% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 32.79457% 92.03093% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 32.23955% 91.78359% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 30.89947% 91.47703% \n",
"model_ner_manatee_all_all_fine-tuning 3.80143% 25.11350% \n",
"model_ner_manatee_non-crossing_all_fine-tuning 2.75918% 13.38802% \n",
"Babelscape/wikineural-multilingual-ner baseline 2.84895% 8.07667% \n",
"model_ner_fuzzy-regex_all_all_fine-tuning 3.32850% 4.96872% "
" PER PER+LOC \n",
"model_ner_manatee_non-crossing_only-relevant_fi... 49.97944% 44.30760% \n",
"model_ner_manatee_all_only-relevant_fine-tuning 49.63405% 43.94366% \n",
"model_ner_manatee_all_only-relevant_parallel 51.62152% 43.57605% \n",
"model_ner_manatee_non-crossing_only-relevant_pa... 49.71262% 42.94907% \n",
"model_ner_manatee_all_all_parallel 43.69512% 41.58149% \n",
"model_ner_manatee_non-crossing_all_parallel 44.26929% 38.61073% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 34.74969% 35.69222% \n",
"model_ner_fuzzy-regex_all_only-relevant_fine-tu... 34.07244% 35.37521% \n",
"model_ner_fuzzy-regex_all_all_parallel 33.04513% 34.98948% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 33.17825% 32.99174% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 31.30961% 31.77318% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 32.51893% 31.66439% \n",
"Babelscape/wikineural-multilingual-ner baseline 7.35338% 3.57886% \n",
"model_ner_fuzzy-regex_all_all_fine-tuning 2.38798% 2.68821% \n",
"model_ner_manatee_all_all_fine-tuning 2.17360% 2.60517% \n",
"model_ner_manatee_non-crossing_all_fine-tuning 2.34774% 2.51480% "
]
},
"metadata": {},
......@@ -718,7 +718,7 @@
],
"source": [
"with pd.option_context('display.float_format', lambda mean_f_score: f'{100.0 * mean_f_score:.5f}%'):\n",
" display(f_scores_df.sort_values(by=['all'], ascending=False))"
" display(f_scores_df.sort_values(by=['PER+LOC'], ascending=False))"
]
},
{
......@@ -731,7 +731,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"id": "f7975e45-ba27-45b4-9b61-1a9c119d434d",
"metadata": {},
"outputs": [
......@@ -739,12 +739,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/nlp/projekty/ahisto/public_html/named-entity-search/results/model_ner_manatee_all_only-relevant_fine-tuning/TokenClassification\n"
"/nlp/projekty/ahisto/public_html/named-entity-search/results/model_ner_manatee_non-crossing_only-relevant_fine-tuning/TokenClassification\n"
]
}
],
"source": [
"all_evaluator_index, = [index for index, evaluator in enumerate(evaluators) if str(evaluator) == 'all']\n",
"all_evaluator_index, = [index for index, evaluator in enumerate(evaluators) if str(evaluator) == 'PER+LOC']\n",
"best_model, _ = max(f_scores_dict.items(), key=lambda x: (x[1][all_evaluator_index], x[0]))\n",
"print(best_model)"
]
......@@ -761,7 +761,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"id": "fffd0beb-ac50-4c81-9eb3-cc225214ff63",
"metadata": {},
"outputs": [],
......@@ -777,7 +777,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 30,
"id": "b83d2a1a-4d8f-400a-a8e8-cba43fe41a83",
"metadata": {},
"outputs": [
......@@ -805,7 +805,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 31,
"id": "de04e6a9-33e5-4e85-9cfc-f0f5bc344677",
"metadata": {},
"outputs": [],
......@@ -815,7 +815,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 32,
"id": "d11a74e1-f3c8-416c-897e-e921a69dc661",
"metadata": {},
"outputs": [],
......@@ -829,7 +829,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 33,
"id": "675dd306-50b0-4245-9904-effff2432921",
"metadata": {},
"outputs": [
......@@ -862,7 +862,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 34,
"id": "2b1c528d-b4ec-4e96-8ca7-ff05fd244d80",
"metadata": {},
"outputs": [
......@@ -890,7 +890,7 @@
"- I-LOC: ##de\n",
"- I-LOC: ##ch\n"
]
},
}
],
"source": [
"tag_sentence(baseline_model, example_sentence)"
......
......@@ -106,3 +106,6 @@ maximum_number_of_training_epochs_per_objective = 1
[recognition.FineTuningSchedule]
maximum_number_of_training_epochs_per_objective = 5
[recognition.AggregateMeanFScoreEvaluator]
default_group_names = PER+LOC
from typing import Dict, Optional, Set, List, Tuple
from typing import Dict, Iterable, Optional, Set, List, Tuple
from functools import total_ordering
from adaptor.evaluators.token_classification import TokenClassificationEvaluator
from adaptor.utils import AdaptationDataset
from more_itertools import zip_equal
import torch
from transformers import PreTrainedTokenizer
from adaptor.evaluators.token_classification import TokenClassificationEvaluator
from adaptor.utils import AdaptationDataset
from ..config import CONFIG as _CONFIG
# Type aliases for the token-classification evaluator.
CategoryName = str  # a token label, e.g. 'B-PER', 'I-LOC' or 'O'
GroupName = str  # a key of AggregateMeanFScoreEvaluator.GROUPS, e.g. 'PER'
GroupNames = Iterable[GroupName]  # several group names that are scored together
Category = int  # integer index assigned to a category name
Group = Set[CategoryName]  # the category names that are merged into one group
CategoryMap = Dict[CategoryName, Category]  # category name -> integer index
......@@ -21,34 +23,32 @@ FScore = float
@total_ordering
class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
CONFIG = _CONFIG['recognition.AggregateMeanFScoreEvaluator']
DEFAULT_GROUP_NAMES = tuple(CONFIG['default_group_names'].split('+'))
GROUPS: GroupMap = {
'PER': {'B-PER', 'I-PER', 'B-ORG', 'I-ORG'},
'LOC': {'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG'},
'O': {'O', 'B-MISC', 'I-MISC'},
}
def __init__(self, category_map: CategoryMap, group_names: Optional[GroupNames],
             *args, **kwargs):
    """Create an evaluator that aggregates F-scores over one or more groups.

    Args:
        category_map: Mapping from category names to integer categories;
            stored unchanged on the instance.
        group_names: Names of the entries of ``GROUPS`` to aggregate over.
            ``None`` selects ``DEFAULT_GROUP_NAMES``. A single string is
            also accepted and is split on ``'+'`` (so ``'PER'`` and
            ``'PER+LOC'`` both work) instead of being exploded into
            individual characters.
        *args: Forwarded to the base-class initializer.
        **kwargs: Forwarded to the base-class initializer.

    Raises:
        ValueError: If any requested group name is not a key of ``GROUPS``.
    """
    if group_names is None:
        resolved_group_names = self.DEFAULT_GROUP_NAMES
    elif isinstance(group_names, str):
        # tuple('PER') would yield ('P', 'E', 'R'); treat a bare string as
        # a '+'-separated list, matching the configuration file format.
        resolved_group_names = tuple(group_names.split('+'))
    else:
        resolved_group_names = tuple(group_names)

    # Fail fast: an unknown name would otherwise only surface as a KeyError
    # in __call__, after the token predictions have already been collected.
    unknown_group_names = [name for name in resolved_group_names if name not in self.GROUPS]
    if unknown_group_names:
        raise ValueError(
            f'Unknown group names {unknown_group_names}, expected a subset of {sorted(self.GROUPS)}')

    self.category_map = category_map
    self.group_names = resolved_group_names
    super().__init__(*args, **kwargs)
def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
             dataset: AdaptationDataset) -> FScore:
    """Evaluate ``model`` on ``dataset`` and return one aggregate F-score.

    For every group named in ``self.group_names`` the per-group sample
    count and F-score are obtained from ``get_f_score``; the result is the
    mean of those F-scores weighted by their sample counts.

    NOTE(review): whether this weighted mean equals a micro-average over
    pooled true/false positives depends on how ``get_f_score`` counts
    samples -- confirm against its implementation.

    Args:
        model: The model whose token predictions are scored.
        tokenizer: Not used by this method; part of the call signature.
        dataset: The dataset the predictions are collected from.

    Returns:
        The weighted mean F-score, or ``0.0`` when no group contributed
        any samples.
    """
    expected_labels, actual_labels = self._collect_token_predictions(model, dataset)

    weighted_f_score_sum, total_number_of_samples = 0.0, 0
    for group_name in self.group_names:
        number_of_samples, f_score = self.get_f_score(
            self.GROUPS[group_name], expected_labels, actual_labels)
        weighted_f_score_sum += number_of_samples * f_score
        total_number_of_samples += number_of_samples

    # Guard the division: with no samples at all the score is defined as 0.
    if total_number_of_samples == 0:
        return 0.0
    return weighted_f_score_sum / total_number_of_samples
......@@ -74,28 +74,21 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
f_score = true_positives / (true_positives + (0.5 * (false_positives + false_negatives)))
return number_of_samples, f_score
@classmethod
def get_all_group_names(cls) -> Set[GroupName]:
    """Return the names of all groups defined in ``GROUPS`` as a new set."""
    return {group_name for group_name in cls.GROUPS}
def __hash__(self):
    """Hash the evaluator by its tuple of group names, consistent with ``__eq__``."""
    return hash(self.group_names)
def __eq__(self, other) -> bool:
    """Compare two evaluators by their group names.

    Returns ``NotImplemented`` for objects that are not
    ``AggregateMeanFScoreEvaluator`` instances so that Python can try the
    reflected comparison.
    """
    if not isinstance(other, AggregateMeanFScoreEvaluator):
        return NotImplemented
    return self.group_names == other.group_names
def __lt__(self, other) -> bool:
    """Order evaluators by tuple comparison of their group names.

    Returns ``NotImplemented`` for objects that are not
    ``AggregateMeanFScoreEvaluator`` instances.
    """
    if not isinstance(other, AggregateMeanFScoreEvaluator):
        return NotImplemented
    return self.group_names < other.group_names
def __repr__(self) -> str:
    """Return the group names joined by ``'+'``, e.g. ``'PER+LOC'``."""
    separator = '+'
    return separator.join(self.group_names)
def __str__(self) -> str:
    """Return the same text as ``repr(self)``."""
    return repr(self)
......@@ -178,6 +178,6 @@ def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScoreEvaluat
for category_index, category
in enumerate(sorted(labels))
}
for group_name in AggregateMeanFScoreEvaluator.get_all_group_names():
yield AggregateMeanFScoreEvaluator(category_map, group_name, decides_convergence=False)
for group_name in sorted(AggregateMeanFScoreEvaluator.GROUPS.keys()):
yield AggregateMeanFScoreEvaluator(category_map, [group_name], decides_convergence=False)
yield AggregateMeanFScoreEvaluator(category_map, None, decides_convergence=True)
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment