Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
nlp
ahisto-modules
Named Entity Recognition Experiments
Commits
cd8adc77
Commit
cd8adc77
authored
Sep 25, 2022
by
Vít Novotný
Browse files
Rename CategoryName and Category to Label and LabelId
parent
f8df7140
Pipeline
#147689
passed with stage
in 10 minutes and 13 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
ahisto_named_entity_search/recognition/evaluator.py
View file @
cd8adc77
...
...
@@ -8,15 +8,15 @@ import torch
from
transformers
import
PreTrainedTokenizer
from
..config
import
CONFIG
as
_CONFIG
from
..search
import
BioNerTag
as
CategoryName
from
..search
import
BioNerTag
as
Label
CategoryName
s
=
Iterable
[
CategoryName
]
Label
s
=
Iterable
[
Label
]
GroupName
=
str
GroupNames
=
Iterable
[
GroupName
]
Category
=
int
Group
=
Set
[
CategoryName
]
Category
Map
=
Dict
[
CategoryName
,
Category
]
LabelId
=
int
Group
=
Set
[
Label
]
Label
Map
=
Dict
[
Label
,
LabelId
]
GroupMap
=
Dict
[
GroupName
,
Group
]
FScore
=
float
...
...
@@ -33,14 +33,14 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
'O'
:
{
'O'
,
'B-MISC'
,
'I-MISC'
},
}
def
__init__
(
self
,
category_map
:
Category
Map
,
group_names
:
Optional
[
GroupNames
],
def
__init__
(
self
,
label_map
:
Label
Map
,
group_names
:
Optional
[
GroupNames
],
*
args
,
**
kwargs
):
self
.
category_map
=
category
_map
self
.
label_map
=
label
_map
self
.
group_names
=
self
.
DEFAULT_GROUP_NAMES
if
group_names
is
None
else
tuple
(
group_names
)
super
().
__init__
(
*
args
,
**
kwargs
)
def
__call__
(
self
,
model
:
torch
.
nn
.
Module
,
tokenizer
:
PreTrainedTokenizer
,
dataset
:
AdaptationDataset
,
ignored_index
:
Category
=
-
100
)
->
FScore
:
dataset
:
AdaptationDataset
,
ignored_index
:
LabelId
=
-
100
)
->
FScore
:
expected_labels
,
actual_labels
=
self
.
_collect_token_predictions
(
model
,
dataset
)
mean_f_score
,
total_number_of_samples
=
0
,
0
...
...
@@ -54,13 +54,13 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
return
mean_f_score
def
get_f_score
(
self
,
group
:
Group
,
expected_labels
:
List
[
Category
],
actual_labels
:
List
[
Category
],
ignored_index
:
Category
)
->
Tuple
[
int
,
FScore
]:
expected_categories
:
Set
[
Category
]
=
{
self
.
category_map
[
category
]
for
category
def
get_f_score
(
self
,
group
:
Group
,
expected_labels
:
List
[
LabelId
],
actual_labels
:
List
[
LabelId
],
ignored_index
:
LabelId
)
->
Tuple
[
int
,
FScore
]:
expected_categories
:
Set
[
LabelId
]
=
{
self
.
label_map
[
line_id
]
for
line_id
in
group
if
category
in
self
.
category
_map
if
line_id
in
self
.
label
_map
}
true_positives
,
false_positives
,
false_negatives
=
0
,
0
,
0
...
...
ahisto_named_entity_search/recognition/model.py
View file @
cd8adc77
...
...
@@ -15,7 +15,7 @@ from ..config import CONFIG as _CONFIG
from
..document
import
Document
,
Sentence
from
..search
import
TaggedSentence
,
BioNerTags
from
.schedule
import
ScheduleName
,
get_schedule
from
.evaluator
import
AggregateMeanFScoreEvaluator
,
FScore
,
CategoryMap
,
CategoryName
from
.evaluator
import
AggregateMeanFScoreEvaluator
,
FScore
,
LabelMap
,
Label
from
.objective
import
BIOTokenPunctuationStrippingClassification
...
...
@@ -38,7 +38,7 @@ class NerModel:
SCHEDULE_NAME
=
CONFIG
[
'schedule'
]
NUM_VALIDATION_SAMPLES
=
CONFIG
.
getint
(
'number_of_validation_samples'
)
STOPPING_PATIENCE
=
CONFIG
.
getint
(
'stopping_patience'
)
LABELS
:
Iterable
[
CategoryName
]
=
(
'B-PER'
,
'I-PER'
,
'B-LOC'
,
'I-LOC'
,
'O'
)
LABELS
:
Iterable
[
Label
]
=
(
'B-PER'
,
'I-PER'
,
'B-LOC'
,
'I-LOC'
,
'O'
)
def
__init__
(
self
,
model_name_or_basename
:
str
,
labels
:
Iterable
[
str
]
=
LABELS
):
self
.
model_name_or_basename
=
model_name_or_basename
...
...
@@ -175,11 +175,11 @@ def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], Lis
def
get_evaluators
(
labels
:
Iterable
[
str
])
->
Iterable
[
AggregateMeanFScoreEvaluator
]:
category_map
:
Category
Map
=
{
category
:
category
_index
for
category_index
,
category
label_map
:
Label
Map
=
{
line_id
:
line_id
_index
for
line_id_index
,
line_id
in
enumerate
(
sorted
(
labels
))
}
for
group_name
in
sorted
(
AggregateMeanFScoreEvaluator
.
GROUPS
.
keys
()):
yield
AggregateMeanFScoreEvaluator
(
category
_map
,
[
group_name
],
decides_convergence
=
False
)
yield
AggregateMeanFScoreEvaluator
(
category
_map
,
None
,
decides_convergence
=
True
)
yield
AggregateMeanFScoreEvaluator
(
label
_map
,
[
group_name
],
decides_convergence
=
False
)
yield
AggregateMeanFScoreEvaluator
(
label
_map
,
None
,
decides_convergence
=
True
)
ahisto_named_entity_search/recognition/objective.py
View file @
cd8adc77
...
...
@@ -6,7 +6,7 @@ import regex
import
torch
from
transformers
import
DataCollatorForTokenClassification
from
.evaluator
import
CategoryName
,
CategoryNames
,
Category
from
.evaluator
import
Label
,
Labels
,
LabelId
Token
=
str
...
...
@@ -14,10 +14,10 @@ Token = str
class
BIOTokenPunctuationStrippingClassification
(
TokenClassification
):
def
_wordpiece_token_label_alignment
(
self
,
texts
:
CategoryName
s
,
labels
:
CategoryName
s
,
texts
:
Label
s
,
labels
:
Label
s
,
label_all_tokens
:
bool
=
True
,
ignore_label_idx
:
Category
=
-
100
)
->
Iterable
[
Dict
[
str
,
torch
.
LongTensor
]]:
ignore_label_idx
:
LabelId
=
-
100
)
->
Iterable
[
Dict
[
str
,
torch
.
LongTensor
]]:
texts
,
labels
=
list
(
texts
),
list
(
labels
)
collator
=
DataCollatorForTokenClassification
(
self
.
tokenizer
,
pad_to_multiple_of
=
8
)
...
...
@@ -58,14 +58,14 @@ class BIOTokenPunctuationStrippingClassification(TokenClassification):
# labels of BoS and EoS are always "other"
out_label_ids
=
[
ignore_label_idx
]
*
len
(
special_bos_tokens
)
def
get_label_type
(
label
:
CategoryName
)
->
CategoryName
:
def
get_label_type
(
label
:
Label
)
->
Label
:
if
label
.
startswith
(
'B-'
)
or
label
.
startswith
(
'I-'
):
label_type
=
label
[
2
:]
else
:
label_type
=
label
return
label_type
def
get_label_ids
(
head_label
:
CategoryName
)
->
Iterable
[
Category
]:
def
get_label_ids
(
head_label
:
Label
)
->
Iterable
[
LabelId
]:
head_label_id
=
self
.
labels_map
[
head_label
]
tail_label
=
f
'I-
{
get_label_type
(
head_label
)
}
'
if
head_label
.
startswith
(
'B-'
)
else
head_label
tail_label_id
=
self
.
labels_map
[
tail_label
]
...
...
@@ -80,8 +80,8 @@ class BIOTokenPunctuationStrippingClassification(TokenClassification):
is_punctuation
=
punctuation_match
is
not
None
return
is_punctuation
def
strip_trailing_punctuation
(
label_ids
:
Iterable
[
Category
],
tokens
:
Iterable
[
Token
])
->
Iterable
[
Category
]:
def
strip_trailing_punctuation
(
label_ids
:
Iterable
[
LabelId
],
tokens
:
Iterable
[
Token
])
->
Iterable
[
LabelId
]:
label_ids
,
tokens
=
list
(
label_ids
),
list
(
tokens
)
assert
len
(
label_ids
)
==
len
(
tokens
)
for
token_index
,
token
in
reversed
(
list
(
enumerate
(
tokens
))):
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment