Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
nlp
ahisto-modules
Named Entity Recognition Experiments
Commits
ee1164d0
Commit
ee1164d0
authored
Sep 19, 2022
by
Vít Novotný
Browse files
Use micro-averaged `PER` and `LOC` F1-score as default objective
parent
51b285b6
Pipeline
#147350
passed with stage
in 8 minutes and 59 seconds
Changes
4
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
03_train_ner_models.ipynb
View file @
ee1164d0
...
...
@@ -237,7 +237,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Loading documents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 268669/268669 [00:0
6
<00:00, 4
4308.95
it/s]\n"
"Loading documents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 268669/268669 [00:0
5
<00:00, 4
7033.84
it/s]\n"
]
}
],
...
...
@@ -462,7 +462,7 @@
},
{
"cell_type": "code",
"execution_count": 2
1
,
"execution_count": 2
2
,
"id": "694daad3-2b04-4e3f-8bfb-bb3fe0c87dd3",
"metadata": {},
"outputs": [],
...
...
@@ -473,7 +473,7 @@
},
{
"cell_type": "code",
"execution_count": 2
2
,
"execution_count": 2
3
,
"id": "fed4d0a4-5bc4-4af2-8e1b-c5a8a6b61c52",
"metadata": {},
"outputs": [],
...
...
@@ -491,7 +491,7 @@
},
{
"cell_type": "code",
"execution_count": 2
3
,
"execution_count": 2
4
,
"id": "38efa732-8afd-4798-809a-ca828a8b960c",
"metadata": {},
"outputs": [],
...
...
@@ -501,7 +501,7 @@
},
{
"cell_type": "code",
"execution_count": 2
4
,
"execution_count": 2
5
,
"id": "1c9d3bab-53de-4c39-8e4d-cd2978f00925",
"metadata": {},
"outputs": [],
...
...
@@ -513,7 +513,7 @@
},
{
"cell_type": "code",
"execution_count": 2
5
,
"execution_count": 2
6
,
"id": "f47f62ca-1164-45b5-94b9-ff082787e8a9",
"metadata": {},
"outputs": [],
...
...
@@ -526,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 2
6
,
"execution_count": 2
7
,
"id": "053ff5dd-775c-431d-a988-18bf3c4f4f6d",
"metadata": {},
"outputs": [
...
...
@@ -551,165 +551,165 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PER</th>\n",
" <th>O</th>\n",
" <th>LOC</th>\n",
" <th>all</th>\n",
" <th>O</th>\n",
" <th>PER</th>\n",
" <th>PER+LOC</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>model_ner_manatee_all_only-relevant_fine-tuning</th>\n",
" <td>49.63405%</td>\n",
" <td>96.74517%</td>\n",
" <td>39.86766%</td>\n",
" <td>94.77053%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_parallel</th>\n",
" <td>43.69512%</td>\n",
" <td>96.65764%</td>\n",
" <td>39.60441%</td>\n",
" <td>94.52101%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_fine-tuning</th>\n",
" <td>49.97944%</td>\n",
" <td>96.66777%</td>\n",
" <td>41.04577%</td>\n",
" <td>94.46477%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_fine-tuning</th>\n",
" <td>34.07244%</td>\n",
" <td>96.38580%</td>\n",
" <td>36.68019%</td>\n",
" <td>93.57176%</td>\n",
" <td>96.66777%</td>\n",
" <td>49.97944%</td>\n",
" <td>44.30760%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_
fuzzy-regex_non-crossing
_only-relevant_fine-tuning</th>\n",
" <td>3
4.74969
%</td>\n",
" <td>96.
34
17
4
%</td>\n",
" <td>
36.64832
%</td>\n",
" <td>
9
3.
41004
%</td>\n",
" <th>model_ner_
manatee_all
_only-relevant_fine-tuning</th>\n",
" <td>3
9.86766
%</td>\n",
" <td>96.
745
17%</td>\n",
" <td>
49.63405
%</td>\n",
" <td>
4
3.
94366
%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_only-relevant_parallel</th>\n",
" <td>51.62152%</td>\n",
" <td>96.17568%</td>\n",
" <td>38.21471%</td>\n",
" <td>93.32879%</td>\n",
" <td>96.17568%</td>\n",
" <td>51.62152%</td>\n",
" <td>43.57605%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_parallel</th>\n",
" <td>38.21755%</td>\n",
" <td>96.10254%</td>\n",
" <td>49.71262%</td>\n",
" <td>42.94907%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_parallel</th>\n",
" <td>39.60441%</td>\n",
" <td>96.65764%</td>\n",
" <td>43.69512%</td>\n",
" <td>41.58149%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_parallel</th>\n",
" <td>44.26929%</td>\n",
" <td>96.20140%</td>\n",
" <td>34.92015%</td>\n",
" <td>93.31223%</td>\n",
" <td>96.20140%</td>\n",
" <td>44.26929%</td>\n",
" <td>38.61073%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_only-relevant_parallel</th>\n",
" <td>49.71262%</td>\n",
" <td>96.10254%</td>\n",
" <td>38.21755%</td>\n",
" <td>93.23893%</td>\n",
" <th>model_ner_fuzzy-regex_non-crossing_only-relevant_fine-tuning</th>\n",
" <td>36.64832%</td>\n",
" <td>96.34174%</td>\n",
" <td>34.74969%</td>\n",
" <td>35.69222%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_fine-tuning</th>\n",
" <td>36.68019%</td>\n",
" <td>96.38580%</td>\n",
" <td>34.07244%</td>\n",
" <td>35.37521%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_all_parallel</th>\n",
" <td>33.04513%</td>\n",
" <td>95.95721%</td>\n",
" <td>37.66807%</td>\n",
" <td>92.63503%</td>\n",
" <td>95.95721%</td>\n",
" <td>33.04513%</td>\n",
" <td>34.98948%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_non-crossing_all_parallel</th>\n",
" <td>33.17825%</td>\n",
" <td>95.71407%</td>\n",
" <td>32.79457%</td>\n",
" <td>92.03093%</td>\n",
" <td>95.71407%</td>\n",
" <td>33.17825%</td>\n",
" <td>32.99174%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_only-relevant_parallel</th>\n",
" <td>31.30961%</td>\n",
" <td>95.59693%</td>\n",
" <td>32.23955%</td>\n",
" <td>91.78359%</td>\n",
" <td>95.59693%</td>\n",
" <td>31.30961%</td>\n",
" <td>31.77318%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_non-crossing_only-relevant_parallel</th>\n",
" <td>32.51893%</td>\n",
" <td>95.45893%</td>\n",
" <td>30.89947%</td>\n",
" <td>91.47703%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_fine-tuning</th>\n",
" <td>2.17360%</td>\n",
" <td>42.86388%</td>\n",
" <td>3.80143%</td>\n",
" <td>25.11350%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_fine-tuning</th>\n",
" <td>2.34774%</td>\n",
" <td>23.59883%</td>\n",
" <td>2.75918%</td>\n",
" <td>13.38802%</td>\n",
" <td>95.45893%</td>\n",
" <td>32.51893%</td>\n",
" <td>31.66439%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Babelscape/wikineural-multilingual-ner baseline</th>\n",
" <td>7.35338%</td>\n",
" <td>13.35824%</td>\n",
" <td>2.84895%</td>\n",
" <td>8.07667%</td>\n",
" <td>13.35824%</td>\n",
" <td>7.35338%</td>\n",
" <td>3.57886%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_fuzzy-regex_all_all_fine-tuning</th>\n",
" <td>2.38798%</td>\n",
" <td>7.33972%</td>\n",
" <td>3.32850%</td>\n",
" <td>4.96872%</td>\n",
" <td>7.33972%</td>\n",
" <td>2.38798%</td>\n",
" <td>2.68821%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_all_all_fine-tuning</th>\n",
" <td>3.80143%</td>\n",
" <td>42.86388%</td>\n",
" <td>2.17360%</td>\n",
" <td>2.60517%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>model_ner_manatee_non-crossing_all_fine-tuning</th>\n",
" <td>2.75918%</td>\n",
" <td>23.59883%</td>\n",
" <td>2.34774%</td>\n",
" <td>2.51480%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"
PER
O \\\n",
"model_ner_manatee_
all
_only-relevant_fi
ne-tuning 49.63405
% 96.
7451
7% \n",
"model_ner_manatee_all_
all_parallel 43.69512% 96.65764
% \n",
"model_ner_manatee_
non-crossing
_only-relevant_
fi... 49.97944% 96.66777
% \n",
"model_ner_
fuzzy-regex_all_only-relevant_fine-tu... 34.07244% 96.38580
% \n",
"model_ner_
fuzzy-regex_non-crossing_only-relevan... 34.74969
% 96.
3417
4% \n",
"model_ner_manatee_
all_only-relevant
_parallel
51.62
15
2
% 96.
17568
% \n",
"model_ner_
manatee
_non-crossing_
all_parallel 44.26929% 96.20140
% \n",
"model_ner_
manatee_non-crossing_only-relevant_pa... 49.71262% 96.10254
% \n",
"model_ner_fuzzy-regex_all_all_parallel 3
3.04513
% 95.95721% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 3
3.17825
% 95.71407% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 3
1.30961
% 95.59693% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 3
2.51893
% 95.45893% \n",
"
model_ner_manatee_all_all_fine-tuning 2.17360% 42.86388
% \n",
"model_ner_
manatee_non-crossing
_all_fine-tuning
2.34774% 23.59883
% \n",
"
Babelscape/wikineural-multilingual-ner basel
in
e
7.35338% 13.35824
% \n",
"model_ner_
fuzzy-regex_all
_all_fine-tuning
2.38798% 7.33972
% \n",
"
LOC
O \\\n",
"model_ner_manatee_
non-crossing
_only-relevant_fi
... 41.04577
% 96.
6677
7% \n",
"model_ner_manatee_all_
only-relevant_fine-tuning 39.86766% 96.74517
% \n",
"model_ner_manatee_
all
_only-relevant_
parallel 38.21471% 96.17568
% \n",
"model_ner_
manatee_non-crossing_only-relevant_pa... 38.21755% 96.10254
% \n",
"model_ner_
manatee_all_all_parallel 39.60441
% 96.
6576
4% \n",
"model_ner_manatee_
non-crossing_all
_parallel
34.920
15% 96.
20140
% \n",
"model_ner_
fuzzy-regex
_non-crossing_
only-relevan... 36.64832% 96.34174
% \n",
"model_ner_
fuzzy-regex_all_only-relevant_fine-tu... 36.68019% 96.38580
% \n",
"model_ner_fuzzy-regex_all_all_parallel 3
7.66807
% 95.95721% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 3
2.79457
% 95.71407% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 3
2.23955
% 95.59693% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 3
0.89947
% 95.45893% \n",
"
Babelscape/wikineural-multilingual-ner baseline 2.84895% 13.35824
% \n",
"model_ner_
fuzzy-regex
_all_
all_
fine-tuning
3.32850% 7.33972
% \n",
"
model_ner_manatee_all_all_fine-tun
in
g
3.80143% 42.86388
% \n",
"model_ner_
manatee_non-crossing
_all_fine-tuning
2.75918% 23.59883
% \n",
"\n",
"
LOC all
\n",
"model_ner_manatee_
all
_only-relevant_fi
ne-tuning 39.86766
%
9
4.
77053
% \n",
"model_ner_manatee_all_
all_parallel
3
9.6
0441% 94.52101
% \n",
"model_ner_manatee_
non-crossing
_only-relevant_
fi... 41.04577% 94.46477
% \n",
"model_ner_
fuzzy-regex_all_only-relevant_fine-tu... 36.68019% 93.57176
% \n",
"model_ner_
fuzzy-regex_non-crossing_only-relevan... 36.64832% 93.41004
% \n",
"model_ner_manatee_
all_only-relevant_parallel 38.21471% 93.32879
% \n",
"model_ner_
manatee
_non-crossing_
all_parallel 34.92015% 93.31
22
3
% \n",
"model_ner_
manatee_non-crossing_only-relevant_pa... 38.21755% 93.23893
% \n",
"model_ner_fuzzy-regex_all_all_parallel 3
7.66807% 92.63503
% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 3
2.79457
%
9
2.
03093
% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 3
2.23955
%
9
1.7
8359
% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 3
0.89947
%
9
1.
47703
% \n",
"
model_ner_manatee_all_all_fine-tuning 3.80143% 25.11350
% \n",
"model_ner_
manatee_non-crossing
_all_fine-tuning
2.75918% 13.38802
% \n",
"
Babelscape/wikineural-multilingual-ner baseline 2.84895% 8.0766
7% \n",
"model_ner_
fuzzy-regex_all
_all_fine-tuning
3.32850% 4.96872
% "
"
PER PER+LOC
\n",
"model_ner_manatee_
non-crossing
_only-relevant_fi
... 49.97944
%
4
4.
30760
% \n",
"model_ner_manatee_all_
only-relevant_fine-tuning
4
9.6
3405% 43.94366
% \n",
"model_ner_manatee_
all
_only-relevant_
parallel 51.62152% 43.57605
% \n",
"model_ner_
manatee_non-crossing_only-relevant_pa... 49.71262% 42.94907
% \n",
"model_ner_
manatee_all_all_parallel 43.69512% 41.58149
% \n",
"model_ner_manatee_
non-crossing_all_parallel 44.26929% 38.61073
% \n",
"model_ner_
fuzzy-regex
_non-crossing_
only-relevan... 34.74969% 35.692
22% \n",
"model_ner_
fuzzy-regex_all_only-relevant_fine-tu... 34.07244% 35.37521
% \n",
"model_ner_fuzzy-regex_all_all_parallel 3
3.04513% 34.98948
% \n",
"model_ner_fuzzy-regex_non-crossing_all_parallel 3
3.17825
%
3
2.
99174
% \n",
"model_ner_fuzzy-regex_all_only-relevant_parallel 3
1.30961
%
3
1.7
7318
% \n",
"model_ner_fuzzy-regex_non-crossing_only-relevan... 3
2.51893
%
3
1.
66439
% \n",
"
Babelscape/wikineural-multilingual-ner baseline 7.35338% 3.57886
% \n",
"model_ner_
fuzzy-regex_all
_all_fine-tuning
2.38798% 2.68821
% \n",
"
model_ner_manatee_all_all_fine-tuning 2.17360% 2.6051
7% \n",
"model_ner_
manatee_non-crossing
_all_fine-tuning
2.34774% 2.51480
% "
]
},
"metadata": {},
...
...
@@ -718,7 +718,7 @@
],
"source": [
"with pd.option_context('display.float_format', lambda mean_f_score: f'{100.0 * mean_f_score:.5f}%'):\n",
" display(f_scores_df.sort_values(by=['
all
'], ascending=False))"
" display(f_scores_df.sort_values(by=['
PER+LOC
'], ascending=False))"
]
},
{
...
...
@@ -731,7 +731,7 @@
},
{
"cell_type": "code",
"execution_count": 2
7
,
"execution_count": 2
8
,
"id": "f7975e45-ba27-45b4-9b61-1a9c119d434d",
"metadata": {},
"outputs": [
...
...
@@ -739,12 +739,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"/nlp/projekty/ahisto/public_html/named-entity-search/results/model_ner_manatee_
all
_only-relevant_fine-tuning/TokenClassification\n"
"/nlp/projekty/ahisto/public_html/named-entity-search/results/model_ner_manatee_
non-crossing
_only-relevant_fine-tuning/TokenClassification\n"
]
}
],
"source": [
"all_evaluator_index, = [index for index, evaluator in enumerate(evaluators) if str(evaluator) == '
all
']\n",
"all_evaluator_index, = [index for index, evaluator in enumerate(evaluators) if str(evaluator) == '
PER+LOC
']\n",
"best_model, _ = max(f_scores_dict.items(), key=lambda x: (x[1][all_evaluator_index], x[0]))\n",
"print(best_model)"
]
...
...
@@ -761,7 +761,7 @@
},
{
"cell_type": "code",
"execution_count": 2
8
,
"execution_count": 2
9
,
"id": "fffd0beb-ac50-4c81-9eb3-cc225214ff63",
"metadata": {},
"outputs": [],
...
...
@@ -777,7 +777,7 @@
},
{
"cell_type": "code",
"execution_count":
29
,
"execution_count":
30
,
"id": "b83d2a1a-4d8f-400a-a8e8-cba43fe41a83",
"metadata": {},
"outputs": [
...
...
@@ -805,7 +805,7 @@
},
{
"cell_type": "code",
"execution_count": 3
0
,
"execution_count": 3
1
,
"id": "de04e6a9-33e5-4e85-9cfc-f0f5bc344677",
"metadata": {},
"outputs": [],
...
...
@@ -815,7 +815,7 @@
},
{
"cell_type": "code",
"execution_count": 3
1
,
"execution_count": 3
2
,
"id": "d11a74e1-f3c8-416c-897e-e921a69dc661",
"metadata": {},
"outputs": [],
...
...
@@ -829,7 +829,7 @@
},
{
"cell_type": "code",
"execution_count": 3
2
,
"execution_count": 3
3
,
"id": "675dd306-50b0-4245-9904-effff2432921",
"metadata": {},
"outputs": [
...
...
@@ -862,7 +862,7 @@
},
{
"cell_type": "code",
"execution_count": 3
3
,
"execution_count": 3
4
,
"id": "2b1c528d-b4ec-4e96-8ca7-ff05fd244d80",
"metadata": {},
"outputs": [
...
...
@@ -890,7 +890,7 @@
"- I-LOC: ##de\n",
"- I-LOC: ##ch\n"
]
}
,
}
],
"source": [
"tag_sentence(baseline_model, example_sentence)"
...
...
ahisto_named_entity_search/default.ini
View file @
ee1164d0
...
...
@@ -106,3 +106,6 @@ maximum_number_of_training_epochs_per_objective = 1
[recognition.FineTuningSchedule]
maximum_number_of_training_epochs_per_objective
=
5
[recognition.AggregateMeanFScoreEvaluator]
default_group_names
=
PER+LOC
ahisto_named_entity_search/recognition/evaluator.py
View file @
ee1164d0
from
typing
import
Dict
,
Optional
,
Set
,
List
,
Tuple
from
typing
import
Dict
,
Iterable
,
Optional
,
Set
,
List
,
Tuple
from
functools
import
total_ordering
from
adaptor.evaluators.token_classification
import
TokenClassificationEvaluator
from
adaptor.utils
import
AdaptationDataset
from
more_itertools
import
zip_equal
import
torch
from
transformers
import
PreTrainedTokenizer
from
adaptor.evaluators.token_classification
import
TokenClassificationEvaluator
from
adaptor.utils
import
AdaptationDataset
from
..config
import
CONFIG
as
_CONFIG
CategoryName
=
str
GroupName
=
str
GroupNames
=
Iterable
[
GroupName
]
Category
=
int
Group
=
Set
[
CategoryName
]
CategoryMap
=
Dict
[
CategoryName
,
Category
]
...
...
@@ -21,34 +23,32 @@ FScore = float
@
total_ordering
class
AggregateMeanFScoreEvaluator
(
TokenClassificationEvaluator
):
CONFIG
=
_CONFIG
[
'recognition.AggregateMeanFScoreEvaluator'
]
DEFAULT_GROUP_NAMES
=
tuple
(
CONFIG
[
'default_group_names'
].
split
(
'+'
))
GROUPS
:
GroupMap
=
{
'PER'
:
{
'B-PER'
,
'I-PER'
,
'B-ORG'
,
'I-ORG'
},
'LOC'
:
{
'B-LOC'
,
'I-LOC'
,
'B-ORG'
,
'I-ORG'
},
'O'
:
{
'O'
,
'B-MISC'
,
'I-MISC'
},
}
def
__init__
(
self
,
category_map
:
CategoryMap
,
group_name
:
Optional
[
GroupName
],
def
__init__
(
self
,
category_map
:
CategoryMap
,
group_name
s
:
Optional
[
GroupName
s
],
*
args
,
**
kwargs
):
self
.
category_map
=
category_map
self
.
group_name
=
group_name
self
.
group_name
s
=
self
.
DEFAULT_GROUP_NAMES
if
group_names
is
None
else
tuple
(
group_names
)
super
().
__init__
(
*
args
,
**
kwargs
)
def
__call__
(
self
,
model
:
torch
.
nn
.
Module
,
tokenizer
:
PreTrainedTokenizer
,
dataset
:
AdaptationDataset
)
->
FScore
:
expected_labels
,
actual_labels
=
self
.
_collect_token_predictions
(
model
,
dataset
)
if
self
.
group_name
is
None
:
mean_f_score
,
total_number_of_samples
=
0
,
0
for
group_name
in
self
.
__class__
.
get_all_group_names
():
number_of_samples
,
f_score
=
self
.
get_f_score
(
self
.
GROUPS
[
group_name
],
expected_labels
,
actual_labels
)
mean_f_score
+=
number_of_samples
*
f_score
total_number_of_samples
+=
number_of_samples
if
total_number_of_samples
>
0
:
_
,
mean_f_score
/=
total_number_of_samples
else
:
group
=
self
.
GROUPS
[
self
.
group_name
]
mean_f_score
=
self
.
get_f_score
(
group
,
expected_labels
,
actual_labels
)
mean_f_score
,
total_number_of_samples
=
0
,
0
for
group_name
in
self
.
group_names
:
number_of_samples
,
f_score
=
self
.
get_f_score
(
self
.
GROUPS
[
group_name
],
expected_labels
,
actual_labels
)
mean_f_score
+=
number_of_samples
*
f_score
total_number_of_samples
+=
number_of_samples
if
total_number_of_samples
>
0
:
mean_f_score
/=
total_number_of_samples
return
mean_f_score
...
...
@@ -74,28 +74,21 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
f_score
=
true_positives
/
(
true_positives
+
(
0.5
*
(
false_positives
+
false_negatives
)))
return
number_of_samples
,
f_score
@
classmethod
def
get_all_group_names
(
cls
)
->
Set
[
GroupName
]:
all_group_names
=
set
(
cls
.
GROUPS
.
keys
())
return
all_group_names
def
__hash__
(
self
):
return
hash
(
self
.
group_name
)
return
hash
(
self
.
group_name
s
)
def
__eq__
(
self
,
other
)
->
bool
:
if
isinstance
(
other
,
AggregateMeanFScoreEvaluator
):
return
self
.
group_name
==
other
.
group_name
return
self
.
group_name
s
==
other
.
group_name
s
return
NotImplemented
def
__lt__
(
self
,
other
)
->
bool
:
if
isinstance
(
other
,
AggregateMeanFScoreEvaluator
):
if
self
.
group_name
is
None
or
other
.
group_name
is
None
:
return
self
.
group_name
is
None
return
self
.
group_name
<
other
.
group_name
return
self
.
group_names
<
other
.
group_names
return
NotImplemented
def
__repr__
(
self
)
->
str
:
return
'+'
.
join
(
self
.
group_names
)
def
__str__
(
self
)
->
str
:
if
self
.
group_name
is
None
:
return
'all'
else
:
return
self
.
group_name
return
repr
(
self
)
ahisto_named_entity_search/recognition/model.py
View file @
ee1164d0
...
...
@@ -178,6 +178,6 @@ def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScoreEvaluat
for
category_index
,
category
in
enumerate
(
sorted
(
labels
))
}
for
group_name
in
AggregateMeanFScoreEvaluator
.
get_all_group_name
s
():
yield
AggregateMeanFScoreEvaluator
(
category_map
,
group_name
,
decides_convergence
=
False
)
for
group_name
in
sorted
(
AggregateMeanFScoreEvaluator
.
GROUPS
.
key
s
()
)
:
yield
AggregateMeanFScoreEvaluator
(
category_map
,
[
group_name
]
,
decides_convergence
=
False
)
yield
AggregateMeanFScoreEvaluator
(
category_map
,
None
,
decides_convergence
=
True
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment