Commit 2287e9ba authored by Vít Novotný's avatar Vít Novotný
Browse files

Add `labels` parameter to `NerModel.__init__()`

parent a15f4ea0
......@@ -2,7 +2,7 @@ from __future__ import annotations
from logging import getLogger
from pathlib import Path
from typing import Tuple, List, Optional
from typing import Tuple, List, Optional, Iterable
import comet_ml # noqa: F401
from adaptor.adapter import Adapter
......@@ -38,9 +38,11 @@ class NerModel:
SCHEDULE_NAME = CONFIG['schedule']
NUM_VALIDATION_SAMPLES = CONFIG.getint('number_of_validation_samples')
STOPPING_PATIENCE = CONFIG.getint('stopping_patience')
LABELS = ['B-PER', 'I-PER', 'B-LOC', 'I-LOC', 'O']
def __init__(self, model_name_or_basename: str):
def __init__(self, model_name_or_basename: str, labels: Iterable[str] = LABELS):
self.model_name_or_basename = model_name_or_basename
self.labels = list(labels)
@property
def tokenizer(self) -> AutoTokenizer:
......@@ -65,8 +67,8 @@ class NerModel:
ner_evaluator = MeanFScore()
ner_objective = TokenClassification(lang_module,
batch_size=1,
texts_or_path=[],
labels_or_path=[],
texts_or_path=['placeholder text'],
labels_or_path=[' '.join(self.labels)],
val_texts_or_path=ner_test_texts,
val_labels_or_path=ner_test_labels,
val_evaluators=[ner_evaluator])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment