nlp / ahisto-modules / Named Entity Recognition Experiments

Commit 0b717e6d, authored Sep 23, 2022 by Vít Novotný

Add `BIOTokenPunctuationStrippingClassification`

parent 224e461d
4 changed files
ahisto_named_entity_search/recognition/__init__.py

@@ -6,6 +6,10 @@ from .model import (
     NerModel
 )
+from .objective import (
+    BIOTokenPunctuationStrippingClassification,
+)
 from .schedule import (
     get_schedule,
     ScheduleName,
@@ -15,6 +19,7 @@ from .schedule import (
 __all__ = [
     'AggregateMeanFScoreEvaluator',
+    'BIOTokenPunctuationStrippingClassification',
     'get_schedule',
     'NerModel',
     'ScheduleName',
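Since the objective is re-exported here, downstream code can import it from the package root. A one-line usage sketch following this `__init__.py`:

    from ahisto_named_entity_search.recognition import BIOTokenPunctuationStrippingClassification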
ahisto_named_entity_search/recognition/evaluator.py

@@ -11,6 +11,7 @@ from ..config import CONFIG as _CONFIG
 CategoryName = str
+CategoryNames = Iterable[CategoryName]
 GroupName = str
 GroupNames = Iterable[GroupName]
 Category = int
@@ -39,13 +40,13 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         super().__init__(*args, **kwargs)

     def __call__(self, model: torch.nn.Module, tokenizer: PreTrainedTokenizer,
-                 dataset: AdaptationDataset) -> FScore:
+                 dataset: AdaptationDataset, ignored_index: Category = -100) -> FScore:
         expected_labels, actual_labels = self._collect_token_predictions(model, dataset)
         mean_f_score, total_number_of_samples = 0, 0
         for group_name in self.group_names:
             number_of_samples, f_score = self.get_f_score(
-                self.GROUPS[group_name], expected_labels, actual_labels)
+                self.GROUPS[group_name], expected_labels, actual_labels, ignored_index)
             mean_f_score += number_of_samples * f_score
             total_number_of_samples += number_of_samples
         if total_number_of_samples > 0:
@@ -54,7 +55,7 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         return mean_f_score

     def get_f_score(self, group: Group, expected_labels: List[Category],
-                    actual_labels: List[Category]) -> Tuple[int, FScore]:
+                    actual_labels: List[Category], ignored_index: Category) -> Tuple[int, FScore]:
         expected_categories: Set[Category] = {self.category_map[category] for category
@@ -64,6 +65,8 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
         true_positives, false_positives, false_negatives = 0, 0, 0
         for expected_label, actual_label in zip_equal(expected_labels, actual_labels):
+            if expected_label == ignored_index:
+                continue
             if expected_label in expected_categories and actual_label in expected_categories:
                 true_positives += 1
             elif expected_label not in expected_categories and actual_label in expected_categories:
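The new `ignored_index` parameter lets the evaluator skip positions whose expected label is the loss-masking sentinel (-100 by default), so masked wordpiece positions no longer count toward the per-group F-scores. A minimal self-contained sketch of the counting rule; function and variable names here are illustrative, not the evaluator's API:

    def count_confusions(expected, actual, positive, ignored_index=-100):
        # Tally true/false positives and false negatives over label ids,
        # skipping positions that are masked out of the loss.
        tp = fp = fn = 0
        for e, a in zip(expected, actual):
            if e == ignored_index:
                continue  # masked position: excluded from scoring
            if e in positive and a in positive:
                tp += 1
            elif e not in positive and a in positive:
                fp += 1
            elif e in positive and a not in positive:
                fn += 1
        return tp, fp, fn

    # e.g. with category ids {1, 2} forming the positive group:
    assert count_confusions([1, -100, 0, 2], [1, 1, 2, 0], {1, 2}) == (1, 1, 1)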
ahisto_named_entity_search/recognition/model.py

@@ -6,7 +6,6 @@ from typing import Tuple, List, Optional, Iterable, Dict
 import comet_ml  # noqa: F401

 from adaptor.adapter import Adapter
-from adaptor.objectives.classification import TokenClassification
 from adaptor.objectives.MLM import MaskedLanguageModeling
 from adaptor.lang_module import LangModule
 from adaptor.utils import StoppingStrategy, AdaptationArguments
@@ -17,6 +16,7 @@ from ..document import Document, Sentence
 from ..search import TaggedSentence, NerTags
 from .schedule import ScheduleName, get_schedule
 from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
+from .objective import BIOTokenPunctuationStrippingClassification

 LOGGER = getLogger(__name__)
@@ -65,13 +65,14 @@ class NerModel:
         ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
         ner_evaluators = list(get_evaluators(self.labels))
-        ner_objective = TokenClassification(lang_module, batch_size=1,
-                                            texts_or_path=['placeholder text'],
-                                            labels_or_path=[' '.join(self.labels)],
-                                            val_texts_or_path=ner_test_texts,
-                                            val_labels_or_path=ner_test_labels,
-                                            val_evaluators=ner_evaluators)
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module, batch_size=1,
+            texts_or_path=['placeholder text'],
+            labels_or_path=[' '.join(self.labels)],
+            val_texts_or_path=ner_test_texts,
+            val_labels_or_path=ner_test_labels,
+            val_evaluators=ner_evaluators)
         adaptation_arguments = AdaptationArguments(output_dir='.',
                                                    stopping_strategy=StoppingStrategy.FIRST_OBJECTIVE_CONVERGED,
@@ -116,13 +117,14 @@ class NerModel:
         ner_validation_labels = ner_validation_labels[:cls.NUM_VALIDATION_SAMPLES]
         ner_evaluators = list(get_evaluators(cls.LABELS))
-        ner_objective = TokenClassification(lang_module, batch_size=cls.BATCH_SIZE,
-                                            texts_or_path=ner_training_texts,
-                                            labels_or_path=ner_training_labels,
-                                            val_texts_or_path=ner_validation_texts,
-                                            val_labels_or_path=ner_validation_labels,
-                                            val_evaluators=ner_evaluators)
+        ner_objective = BIOTokenPunctuationStrippingClassification(
+            lang_module, batch_size=cls.BATCH_SIZE,
+            texts_or_path=ner_training_texts,
+            labels_or_path=ner_training_labels,
+            val_texts_or_path=ner_validation_texts,
+            val_labels_or_path=ner_validation_labels,
+            val_evaluators=ner_evaluators)

         # Train MLM and NER in parallel until convergence on validation
         model_checkpoint_pathname = cls.ROOT_PATH / model_checkpoint_basename
@@ -159,7 +161,7 @@ class NerModel:
     @classmethod
     def load(cls, model_basename: str) -> 'NerModel':
         model_pathname = cls.ROOT_PATH / model_basename
-        model_pathname = model_pathname / 'TokenClassification'
+        model_pathname = model_pathname / 'BIOTokenPunctuationStrippingClassification'
         model_name_or_basename = str(model_pathname)
         return cls(model_name_or_basename)
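The change in `load` mirrors the objective swap: the trained head is looked up under a subdirectory named after the objective class, which this diff suggests is how the surrounding framework lays out its checkpoints. A hedged sketch of the resulting path; the root and basename are illustrative, not taken from the commit:

    from pathlib import Path

    ROOT_PATH = Path('.')                    # assumption: stands in for NerModel.ROOT_PATH
    model_pathname = ROOT_PATH / 'my_model'  # illustrative model basename
    model_pathname = model_pathname / 'BIOTokenPunctuationStrippingClassification'
    print(model_pathname)  # my_model/BIOTokenPunctuationStrippingClassification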
ahisto_named_entity_search/recognition/objective.py
0 → 100644
View file @
0b717e6d
from
itertools
import
islice
from
typing
import
Dict
,
Iterable
from
adaptor.objectives.classification
import
TokenClassification
import
regex
import
torch
from
transformers
import
DataCollatorForTokenClassification
from
.evaluator
import
CategoryName
,
CategoryNames
,
Category
Token
=
str
class
BIOTokenPunctuationStrippingClassification
(
TokenClassification
):
def
_wordpiece_token_label_alignment
(
self
,
texts
:
CategoryNames
,
labels
:
CategoryNames
,
label_all_tokens
:
bool
=
True
,
ignore_label_idx
:
Category
=
-
100
)
->
Iterable
[
Dict
[
str
,
torch
.
LongTensor
]]:
texts
,
labels
=
list
(
texts
),
list
(
labels
)
collator
=
DataCollatorForTokenClassification
(
self
.
tokenizer
,
pad_to_multiple_of
=
8
)
batch_features
=
[]
# special tokens identification: general heuristic
ids1
=
self
.
tokenizer
(
"X"
).
input_ids
ids2
=
self
.
tokenizer
(
"Y"
).
input_ids
special_bos_tokens
=
[]
for
i
in
range
(
len
(
ids1
)):
if
ids1
[
i
]
==
ids2
[
i
]:
special_bos_tokens
.
append
(
ids1
[
i
])
else
:
break
special_eos_tokens
=
[]
for
i
in
range
(
1
,
len
(
ids1
)):
if
ids1
[
-
i
]
==
ids2
[
-
i
]:
special_eos_tokens
.
append
(
ids1
[
-
i
])
else
:
break
special_eos_tokens
=
list
(
reversed
(
special_eos_tokens
))
# per-sample iteration
for
text
,
text_labels
in
zip
(
texts
,
labels
):
tokens
=
text
.
split
()
labels
=
text_labels
.
split
()
assert
len
(
tokens
)
==
len
(
labels
),
\
"A number of tokens in the first line is different than a number of labels. "
\
"Text: %s
\n
Labels: %s"
%
(
text
,
text_labels
)
tokens_ids
=
self
.
tokenizer
(
tokens
,
truncation
=
True
,
add_special_tokens
=
False
).
input_ids
wpiece_ids
=
special_bos_tokens
.
copy
()
# labels of BoS and EoS are always "other"
out_label_ids
=
[
ignore_label_idx
]
*
len
(
special_bos_tokens
)
def
get_label_type
(
label
:
CategoryName
)
->
CategoryName
:
if
label
.
startswith
(
'B-'
)
or
label
.
startswith
(
'I-'
):
label_type
=
label
[
2
:]
else
:
label_type
=
label
return
label_type
def
get_label_ids
(
head_label
:
CategoryName
)
->
Iterable
[
Category
]:
head_label_id
=
self
.
labels_map
[
head_label
]
tail_label
=
f
'I-
{
get_label_type
(
head_label
)
}
'
if
head_label
.
startswith
(
'B-'
)
else
head_label
tail_label_id
=
self
.
labels_map
[
tail_label
]
yield
head_label_id
while
True
:
yield
tail_label_id
def
is_punctuation
(
token
:
Token
)
->
bool
:
punctuation_regex
=
r
'^\W*$'
punctuation_match
=
regex
.
match
(
punctuation_regex
,
token
)
is_punctuation
=
punctuation_match
is
not
None
return
is_punctuation
def
strip_trailing_punctuation
(
label_ids
:
Iterable
[
Category
],
tokens
:
Iterable
[
Token
])
->
Iterable
[
Category
]:
label_ids
,
tokens
=
list
(
label_ids
),
list
(
tokens
)
assert
len
(
label_ids
)
==
len
(
tokens
)
for
token_index
,
token
in
reversed
(
list
(
enumerate
(
tokens
))):
if
is_punctuation
(
token
):
label_ids
[
token_index
]
=
self
.
labels_map
[
'O'
]
else
:
break
return
label_ids
for
label_index
,
(
token_ids
,
label
)
in
enumerate
(
zip
(
tokens_ids
,
labels
)):
# chain the wordpieces without the special symbols for each token
wpiece_ids
.
extend
(
token_ids
)
if
label_all_tokens
:
# label all wordpieces
label_ids
=
get_label_ids
(
label
)
label_ids
=
islice
(
label_ids
,
len
(
token_ids
))
if
label
!=
'O'
:
if
label_index
+
1
>=
len
(
labels
):
should_strip_punctuation
=
False
else
:
label_type
=
get_label_type
(
label
)
next_label
=
labels
[
label_index
+
1
]
next_label_type
=
get_label_type
(
next_label
)
should_strip_punctuation
=
label_type
!=
next_label_type
if
should_strip_punctuation
:
tokens
=
self
.
tokenizer
.
batch_decode
(
token_ids
)
label_ids
=
strip_trailing_punctuation
(
label_ids
,
tokens
)
out_label_ids
.
extend
(
label_ids
)
else
:
# label only the first wordpiece
out_label_ids
.
append
(
self
.
labels_map
[
label
])
# ignore the predictions over other token's wordpieces from the loss
out_label_ids
.
extend
([
ignore_label_idx
]
*
(
len
(
token_ids
)
-
1
))
out_label_ids
.
extend
([
ignore_label_idx
]
*
len
(
special_eos_tokens
))
wpiece_ids
.
extend
(
special_eos_tokens
.
copy
())
assert
len
(
out_label_ids
)
==
len
(
wpiece_ids
),
\
"We found misaligned labels in sample: '%s'"
%
text
if
self
.
tokenizer
.
model_max_length
is
None
:
truncated_size
=
len
(
out_label_ids
)
else
:
truncated_size
=
min
(
self
.
tokenizer
.
model_max_length
,
len
(
out_label_ids
))
batch_features
.
append
({
"input_ids"
:
wpiece_ids
[:
truncated_size
],
"attention_mask"
:
[
1
]
*
truncated_size
,
"labels"
:
out_label_ids
[:
truncated_size
]})
# maybe yield a batch
if
len
(
batch_features
)
==
self
.
batch_size
:
yield
collator
(
batch_features
)
batch_features
=
[]
if
batch_features
:
yield
collator
(
batch_features
)
# check that the number of outputs of the selected compatible head matches the just-parsed data set
num_outputs
=
list
(
self
.
compatible_head_model
.
parameters
())[
-
1
].
shape
[
0
]
num_labels
=
len
(
self
.
labels_map
)
assert
num_outputs
==
num_labels
,
"A number of the outputs for the selected %s head (%s) "
\
"does not match a number of token labels (%s)"
\
%
(
self
.
compatible_head
,
num_outputs
,
num_labels
)
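Two pieces of the alignment logic above are easy to sanity-check in isolation: the special-token heuristic takes the shared prefix and suffix of the input ids for two different one-character texts, and the punctuation stripper relabels an entity's trailing punctuation wordpieces as 'O'. A small demonstration; the tokenizer checkpoint and the toy label map are illustrative assumptions, not part of the commit:

    import regex
    from transformers import AutoTokenizer

    # Special-token heuristic: ids shared by two different texts at the start
    # (end) of the sequence are BoS (EoS) specials.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')  # illustrative checkpoint
    ids1 = tokenizer('X').input_ids  # [CLS] X [SEP], e.g. [101, 161, 102]
    ids2 = tokenizer('Y').input_ids  # [CLS] Y [SEP], e.g. [101, 162, 102]
    assert ids1[0] == ids2[0] and ids1[-1] == ids2[-1]  # shared [CLS] and [SEP]

    # Punctuation stripping at an entity boundary, with a toy label map:
    labels_map = {'O': 0, 'B-PER': 1, 'I-PER': 2}

    def is_punctuation(token: str) -> bool:
        return regex.match(r'^\W*$', token) is not None

    pieces = ['Huss', ',']                                  # decoded wordpieces of 'Huss,'
    label_ids = [labels_map['B-PER'], labels_map['I-PER']]  # before stripping
    for index, piece in reversed(list(enumerate(pieces))):
        if is_punctuation(piece):
            label_ids[index] = labels_map['O']              # the trailing ',' leaves the entity
        else:
            break
    print(label_ids)  # [1, 0]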