Commit f8df7140 authored by Vít Starý Novotný
Browse files

Include BIO tags directly in produced datasets

parent c0f68d9c
Loading
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
from .entity import (
    Entity,
    load_entities,
+    NerTag,
    Person,
    Place,
    Regest,
@@ -27,6 +28,7 @@ __all__ = [
    'Entity',
    'EntityType',
    'load_entities',
+    'NerTag',
    'Patch',
    'PatchConfirmatory',
    'Patchset',
+4 −1
Original line number Diff line number Diff line
@@ -23,6 +23,9 @@ LOGGER = getLogger(__name__)
COMPOSITE_ENTITY = re.compile(r'(?P<before_parens>.*)\((?P<inside_parens>[^)]*)\)\s*')


+NerTag = str


@total_ordering
class Regest:
    ROOT_PATH = Path(CONFIG['root_path'])
@@ -109,7 +112,7 @@ class Entity(ABC):

    @property
    @abstractmethod
-    def ner_tag(self) -> str:
+    def ner_tag(self) -> NerTag:
        pass

    @staticmethod
+1 −1
Original line number Diff line number Diff line
@@ -8,9 +8,9 @@ import torch
from transformers import PreTrainedTokenizer

from ..config import CONFIG as _CONFIG
+from ..search import BioNerTag as CategoryName


-CategoryName = str
CategoryNames = Iterable[CategoryName]
GroupName = str
GroupNames = Iterable[GroupName]
+3 −3
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ from transformers import AutoModelForTokenClassification, AutoTokenizer

from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
-from ..search import TaggedSentence, NerTags
+from ..search import TaggedSentence, BioNerTags
from .schedule import ScheduleName, get_schedule
from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName
from .objective import BIOTokenPunctuationStrippingClassification
@@ -166,11 +166,11 @@ class NerModel:
        return cls(model_name_or_basename)


-def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[NerTags]]:
+def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], List[BioNerTags]]:
    ner_texts, all_ner_tags = [], []
    for tagged_sentence in TaggedSentence.load(tagged_sentence_basename):
        ner_texts.append(tagged_sentence.sentence)
-        all_ner_tags.append(tagged_sentence.bio_ner_tags)
+        all_ner_tags.append(tagged_sentence.ner_tags)
    return ner_texts, all_ner_tags


+4 −2
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from .evaluation import (
)

from .result import (
-    NerTags,
+    BioNerTag,
+    BioNerTags,
    SearchResultList,
    TaggedSentence,
)
@@ -22,6 +23,7 @@ __all__ = [
    'Evaluation',
    'Search',
    'SearchResultList',
-    'NerTags',
+    'BioNerTag',
+    'BioNerTags',
    'TaggedSentence',
]
Loading