Unverified commit cb22f6e9, authored by Vít Novotný
Browse files

Add type annotations to scripts.evaluate

parent 8f3395d5
......@@ -6,6 +6,7 @@ from multiprocessing import Pool
import os.path
import re
import sys
from typing import Tuple
from pytrec_eval import parse_run
from tqdm import tqdm
......@@ -14,7 +15,7 @@ from .common import get_ndcg, get_random_ndcg
from .configuration import TASKS, USER_README_HEAD, TASK_README_HEAD
def evaluate_worker(args):
def evaluate_worker(args) -> Tuple[str, float]:
task, result_filename = args
result_name = re.sub('_', ', ', os.path.basename(result_filename)[:-4])
with open(result_filename, 'rt') as f:
......@@ -23,7 +24,7 @@ def evaluate_worker(args):
return (result_name, ndcg)
def produce_leaderboards():
def produce_leaderboards() -> None:
for task in TASKS:
if not os.path.exists(task):
continue
......@@ -66,7 +67,7 @@ def produce_leaderboards():
f_readme.write('| %.4f | %s | %s |\n' % (ndcg, result_name, user_name))
def evaluate_run(filename, subset, year, confidence=95.0):
def evaluate_run(filename, subset, year, confidence=95.0) -> Tuple[float, float]:
with open(filename, 'rt') as f:
lines = [line.strip().split() for line in f]
first_line = lines[0]
......@@ -93,16 +94,18 @@ def evaluate_run(filename, subset, year, confidence=95.0):
parsed_result[topic_id][result_id] = 1.0 / (int(rank) + rank_offset)
ndcg, interval = get_ndcg(parsed_result, task, subset, confidence=confidence)
print('%.3f, %g%% CI: [%.3f; %.3f]' % (ndcg, confidence, *interval))
return (ndcg, interval)
if __name__ == '__main__':
if len(sys.argv) == 1:
produce_leaderboards()
elif len(sys.argv) == 2:
evaluate_run(sys.argv[1], 'all', 2020)
elif len(sys.argv) == 3:
evaluate_run(sys.argv[1], sys.argv[2], 2020)
elif len(sys.argv) == 4:
evaluate_run(sys.argv[1], sys.argv[2], int(sys.argv[3]))
elif len(sys.argv) > 1 and len(sys.argv) <= 4:
if len(sys.argv) == 2:
evaluate_run(sys.argv[1], 'all', 2020)
if len(sys.argv) == 3:
evaluate_run(sys.argv[1], sys.argv[2], 2020)
if len(sys.argv) == 4:
evaluate_run(sys.argv[1], sys.argv[2], int(sys.argv[3]))
else:
raise ValueError("Usage: {} [TSV_FILE [SUBSET [YEAR]]]".format(sys.argv[0]))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment