Commit 6c760a64 authored by Vít Novotný's avatar Vít Novotný
Browse files

Explicitly use sets in `ahisto_named_entity_search.recognition.evaluator`

parent 7e5701eb
Pipeline #147226 passed with stage
in 8 minutes and 51 seconds
from typing import Dict, Optional, Set, Iterable, List
from typing import Dict, Optional, Set, List
from more_itertools import zip_equal
import torch
......@@ -11,7 +11,7 @@ from adaptor.utils import AdaptationDataset
CategoryName = str
GroupName = str
Category = int
Group = Iterable[CategoryName]
Group = Set[CategoryName]
CategoryMap = Dict[CategoryName, Category]
GroupMap = Dict[GroupName, Group]
......@@ -20,9 +20,9 @@ FScore = float
class AggregateMeanFScore(TokenClassificationEvaluator):
GROUPS: GroupMap = {
'PER': ('B-PER', 'I-PER', 'B-ORG', 'I-ORG'),
'LOC': ('B-LOC', 'I-LOC', 'B-ORG', 'I-ORG'),
'O': ('O', 'B-MISC', 'I-MISC'),
'PER': {'B-PER', 'I-PER', 'B-ORG', 'I-ORG'},
'LOC': {'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG'},
'O': {'O', 'B-MISC', 'I-MISC'},
def __init__(self, category_map: CategoryMap, group_name: Optional[GroupName],
......@@ -66,9 +66,9 @@ class AggregateMeanFScore(TokenClassificationEvaluator):
return f_score
def get_all_group_names(cls) -> Iterable[GroupName]:
for group_name in cls.GROUPS.keys():
yield group_name
def get_all_group_names(cls) -> Set[GroupName]:
all_group_names = set(cls.GROUPS.keys())
return all_group_names
def __str__(self) -> str:
if self.group_name is None:
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment