Commit f8df7140 authored by Vít Novotný's avatar Vít Novotný
Browse files

Include BIO tags directly in produced datasets

parent c0f68d9c
Pipeline #147688 passed with stage
in 9 minutes and 57 seconds
from .entity import (
Entity,
load_entities,
NerTag,
Person,
Place,
Regest,
......@@ -27,6 +28,7 @@ __all__ = [
'Entity',
'EntityType',
'load_entities',
'NerTag',
'Patch',
'PatchConfirmatory',
'Patchset',
......
......@@ -23,6 +23,9 @@ LOGGER = getLogger(__name__)
COMPOSITE_ENTITY = re.compile(r'(?P<before_parens>.*)\((?P<inside_parens>[^)]*)\)\s*')
NerTag = str
@total_ordering
class Regest:
ROOT_PATH = Path(CONFIG['root_path'])
......@@ -109,7 +112,7 @@ class Entity(ABC):
@property
@abstractmethod
def ner_tag(self) -> str:
def ner_tag(self) -> NerTag:
pass
@staticmethod
......
......@@ -8,9 +8,9 @@ import torch
from transformers import PreTrainedTokenizer
from ..config import CONFIG as _CONFIG
from ..search import BioNerTag as CategoryName
CategoryName = str
CategoryNames = Iterable[CategoryName]
GroupName = str
GroupNames = Iterable[GroupName]
......
......@@ -13,7 +13,7 @@ from transformers import AutoModelForTokenClassification, AutoTokenizer
from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
from ..search import TaggedSentence, NerTags
from ..search import TaggedSentence, BioNerTags
from .schedule import ScheduleName, get_schedule
from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
from .objective import BIOTokenPunctuationStrippingClassification
......@@ -166,11 +166,11 @@ class NerModel:
return cls(model_name_or_basename)
def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[BioNerTags]]:
ner_texts, all_ner_tags = [], []
for tagged_sentence in TaggedSentence.load(tagged_sentence_basename):
ner_texts.append(tagged_sentence.sentence)
all_ner_tags.append(tagged_sentence.bio_ner_tags)
all_ner_tags.append(tagged_sentence.ner_tags)
return ner_texts, all_ner_tags
......
......@@ -7,7 +7,8 @@ from .evaluation import (
)
from .result import (
NerTags,
BioNerTag,
BioNerTags,
SearchResultList,
TaggedSentence,
)
......@@ -22,6 +23,7 @@ __all__ = [
'Evaluation',
'Search',
'SearchResultList',
'NerTags',
'BioNerTag',
'BioNerTags',
'TaggedSentence',
]
......@@ -2,6 +2,7 @@ from __future__ import annotations
from datetime import timedelta
from collections import defaultdict
from itertools import islice
import json
from logging import getLogger
from statistics import mean
......@@ -16,7 +17,7 @@ from humanize import naturaldelta
import regex
from ..document import Document, Documents, Sentence
from ..entity import Entity
from ..entity import Entity, NerTag
from ..config import CONFIG as _CONFIG
if TYPE_CHECKING: # avoid circular dependency
......@@ -35,14 +36,15 @@ EntityRepr = str
AllResults = Tuple[Dict[Entity, 'Results'], Duration, IndexName]
AllSerializedResults = List
NerTags = str
BioNerTag = str
BioNerTags = str
class TaggedSentence:
CONFIG = _CONFIG['search.TaggedSentence']
ROOT_PATH = Path(CONFIG['root_path'])
def __init__(self, sentence: Sentence, ner_tags: NerTags):
def __init__(self, sentence: Sentence, ner_tags: BioNerTags):
sentence = regex.sub(r'\n+', ' ', sentence)
assert '\n' not in sentence
assert '\n' not in ner_tags
......@@ -60,27 +62,10 @@ class TaggedSentence:
return sentence
@property
def ner_tags(self) -> NerTags:
def ner_tags(self) -> BioNerTags:
ner_tags = ' '.join(self.ner_tags_tuple)
return ner_tags
@property
def bio_ner_tags(self) -> NerTags:
previous_ner_tag = None
bio_ner_tags_list = []
for ner_tag in self.ner_tags_tuple:
if ner_tag == 'O':
bio_ner_tag = ner_tag
else:
if previous_ner_tag is None or ner_tag != previous_ner_tag:
bio_ner_tag = f'B-{ner_tag}'
else:
bio_ner_tag = f'I-{ner_tag}'
bio_ner_tags_list.append(bio_ner_tag)
previous_ner_tag = ner_tag
bio_ner_tags = ' '.join(bio_ner_tags_list)
return bio_ner_tags
@classmethod
def save(cls, basename: str, tagged_sentences: Iterable['TaggedSentence']) -> None:
sentences_filename = cls._get_sentences_filename(basename)
......@@ -150,9 +135,9 @@ class TaggedSentence:
for word, ner_tag in zip(self.sentence_tuple, self.ner_tags_tuple):
if ner_tag == 'O':
formatted_word = word
elif ner_tag == 'PER':
elif ner_tag[2:] == 'PER':
formatted_word = bold(escape(word))
elif ner_tag == 'LOC':
elif ner_tag[2:] == 'LOC':
formatted_word = italics(escape(word))
else:
raise ValueError(f'Unknown tag "{ner_tag}"')
......@@ -281,10 +266,17 @@ class SearchResultList:
len(match_and_right_context_tokens) == len(match_tokens) + len(right_context_tokens)
)
def get_bio_ner_tags(ner_tag: NerTag) -> Iterable[BioNerTag]:
assert ner_tag != 'O'
head_ner_tag, tail_ner_tag = f'B-{ner_tag}', f'I-{ner_tag}'
yield head_ner_tag
while True:
yield tail_ner_tag
ner_tags_list = (
(len(left_context_tokens) - 1) * ['O'] +
(['O'] if is_left_context_separate and left_context_tokens else []) +
len(match_tokens) * [entity.ner_tag] +
list(islice(get_bio_ner_tags(entity.ner_tag), len(match_tokens))) +
(len(right_context_tokens) - 1) * ['O'] +
(['O'] if is_right_context_separate and right_context_tokens else [])
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment