Commit abb3e082 authored by Vít Novotný's avatar Vít Novotný
Browse files

Fix typos in `ahisto_named_entity_search.recognition.model`

parent 47e65c62
......@@ -22,6 +22,9 @@ from .schedule import ScheduleName, get_schedule
LOGGER = getLogger(__name__)
EvaluationResult = float
class NerModel:
CONFIG = _CONFIG['recognition.NerModel']
ROOT_PATH = Path(CONFIG['root_path'])
......@@ -49,25 +52,28 @@ class NerModel:
def __repr__(self) -> str:
return '{}: {}'.format(self.__class__.__name__, self)
def test(self, test_tagged_sentence_basename: str) -> float:
def test(self, test_tagged_sentence_basename: str) -> EvaluationResult:
lang_module = LangModule(self.model_name_or_basename)
ner_test_texts, ner_test_labels = load_ner_dataset(test_tagged_sentence_basename)
ner_evaluators = [MeanFScore(evaluation_strategy='steps')]
ner_evaluator = MeanFScore()
ner_objective = TokenClassification(lang_module,
batch_size=1,
texts_or_path=ner_test_texts,
labels_or_path=ner_test_labels,
val_texts_or_path=ner_testn_texts,
val_labels_or_path=ner_testn_labels,
val_evaluators=ner_evaluators)
adaptation_arguments = AdaptationArguments()
texts_or_path=[],
labels_or_path=[],
val_texts_or_path=ner_test_texts,
val_labels_or_path=ner_test_labels,
val_evaluators=[ner_evaluator])
adaptation_arguments = AdaptationArguments(
output_dir='.',
stopping_strategy=StoppingStrategy.FIRST_OBJECTIVE_CONVERGED,
evaluation_strategy='steps',
)
schedule = get_schedule('sequential', [ner_objective], adaptation_arguments)
adapter = Adapter(lang_module, schedule, adaptation_arguments)
test_result = adapter.evaluate()
test_f_score = test_result['eval_MeanFScore']
test_result = ner_objective.per_objective_log("eval")
test_f_score = test_result[f'eval_{ner_objective}_{ner_evaluator}']
return test_f_score
@classmethod
......@@ -139,7 +145,7 @@ class NerModel:
return cls.load(model_basename)
@classmethod
def load(cls, basename: str) -> 'NerModel':
def load(cls, model_basename: str) -> 'NerModel':
model_pathname = cls.ROOT_PATH / model_basename
model_pathname = model_pathname / 'TokenClassification'
model_name_or_basename = str(model_pathname)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment