Commit 2287e9ba authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Add `labels` parameter to `NerModel.__init__()`

parent a15f4ea0
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from __future__ import annotations

from logging import getLogger
from pathlib import Path
from typing import Tuple, List, Optional
from typing import Tuple, List, Optional, Iterable

import comet_ml  # noqa: F401
from adaptor.adapter import Adapter
@@ -38,9 +38,11 @@ class NerModel:
    SCHEDULE_NAME = CONFIG['schedule']
    NUM_VALIDATION_SAMPLES = CONFIG.getint('number_of_validation_samples')
    STOPPING_PATIENCE = CONFIG.getint('stopping_patience')
    LABELS = ['B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O']

    def __init__(self, model_name_or_basename: str):
    def __init__(self, model_name_or_basename: str, labels: Iterable[str] = LABELS):
        self.model_name_or_basename = model_name_or_basename
        self.labels = list(labels)

    @property
    def tokenizer(self) -> AutoTokenizer:
@@ -65,8 +67,8 @@ class NerModel:
        ner_evaluator = MeanFScore()
        ner_objective = TokenClassification(lang_module,
                                            batch_size=1,
                                            texts_or_path=[],
                                            labels_or_path=[],
                                            texts_or_path=['placeholder text'],
                                            labels_or_path=[' '.join(self.labels)],
                                            val_texts_or_path=ner_test_texts,
                                            val_labels_or_path=ner_test_labels,
                                            val_evaluators=[ner_evaluator])