Commit b36d1d79 authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Rename `AggregateMeanFScore` to `AggregateMeanFScoreEvaluator`

parent 6c760a64
Loading
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
from .evaluator import (
    AggregateMeanFScore,
    AggregateMeanFScoreEvaluator,
)

from .model import (
@@ -14,7 +14,7 @@ from .schedule import (


__all__ = [
    'AggregateMeanFScore',
    'AggregateMeanFScoreEvaluator',
    'get_schedule',
    'NerModel',
    'ScheduleName',
+1 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ GroupMap = Dict[GroupName, Group]
FScore = float


class AggregateMeanFScore(TokenClassificationEvaluator):
class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
    GROUPS: GroupMap = {
        'PER': {'B-PER', 'I-PER', 'B-ORG', 'I-ORG'},
        'LOC': {'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG'},
+6 −6
Original line number Diff line number Diff line
@@ -16,13 +16,13 @@ from ..config import CONFIG as _CONFIG
from ..document import Document, Sentence
from ..search import TaggedSentence, NerTags
from .schedule import ScheduleName, get_schedule
from .evaluator import AggregateMeanFScore, FScore, CategoryMap, CategoryName
from .evaluator import AggregateMeanFScoreEvaluator, FScore, CategoryMap, CategoryName


LOGGER = getLogger(__name__)


EvaluationResult = Dict[AggregateMeanFScore, FScore]
EvaluationResult = Dict[AggregateMeanFScoreEvaluator, FScore]


class NerModel:
@@ -172,12 +172,12 @@ def load_ner_dataset(tagged_sentence_basename: str) -> Tuple[List[Sentence], Lis
    return ner_texts, all_ner_tags


def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScore]:
def get_evaluators(labels: Iterable[str]) -> Iterable[AggregateMeanFScoreEvaluator]:
    category_map: CategoryMap = {
        category: category_index
        for category_index, category
        in enumerate(sorted(labels))
    }
    for group_name in AggregateMeanFScore.get_all_group_names():
        yield AggregateMeanFScore(category_map, group_name, decides_convergence=False)
    yield AggregateMeanFScore(category_map, None, decides_convergence=True)
    for group_name in AggregateMeanFScoreEvaluator.get_all_group_names():
        yield AggregateMeanFScoreEvaluator(category_map, group_name, decides_convergence=False)
    yield AggregateMeanFScoreEvaluator(category_map, None, decides_convergence=True)