Commit b400288e authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Make `AggregateMeanFScoreEvaluator` hashable and totally ordered

parent b36d1d79
Loading
Loading
Loading
Loading
Loading
+17 −0
Original line number Diff line number Diff line
from typing import Dict, Optional, Set, List
from functools import total_ordering

from more_itertools import zip_equal
import torch
@@ -18,6 +19,7 @@ GroupMap = Dict[GroupName, Group]
FScore = float


@total_ordering
class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
    GROUPS: GroupMap = {
        'PER': {'B-PER', 'I-PER', 'B-ORG', 'I-ORG'},
@@ -70,6 +72,21 @@ class AggregateMeanFScoreEvaluator(TokenClassificationEvaluator):
        all_group_names = set(cls.GROUPS.keys())
        return all_group_names

    def __hash__(self):
        return hash(self.group_name)

    def __eq__(self, other) -> bool:
        if isinstance(other, AggregateMeanFScoreEvaluator):
            return self.group_name == other.group_name
        return NotImplemented

    def __lt__(self, other) -> bool:
        if isinstance(other, AggregateMeanFScoreEvaluator):
            if self.group_name is None or other.group_name is None:
                return self.group_name is None
            return self.group_name < other.group_name
        return NotImplemented

    def __str__(self) -> str:
        if self.group_name is None:
            return 'all'