Commit 8c134e16 authored by Zuzana Nevěřilová's avatar Zuzana Nevěřilová
Browse files

refactor evaluation data

parent d2824e1a
Loading
Loading
Loading
Loading
+31 −20
Original line number Diff line number Diff line
@@ -34,18 +34,26 @@ def read_database(database_file, project_id):

def make_dataframe(annotations, opentapioca):
    data = []
    print("opentapioca keys", opentapioca.keys())
    for annotation in annotations.values():
        text, true_labels, pred_labels = annotation
        filename = text.get("document")
        labels_tapioca = sorted(opentapioca.get('choices'))
        if filename not in opentapioca.keys():
            print("Skipping", filename, "not found in OpenTapioca predictions")
            continue
        labels_tapioca = (opentapioca.get(filename, {}).get('choices'))
        options = dict([(v.get('value'), v.get('html').split('-')[0]) for v in text.get("options")])
    #    print('options', options)
        labels_true = [v.get("value", {}).get("choices") for v in true_labels if "choices" in v.get("value",{})]
        if labels_true:
            labels_true = sorted(labels_true[0])
            labels_true = labels_true[0]
        labels_pred = [v.get("value", {}).get("choices") for v in pred_labels if "choices" in v.get("value",{})]
        labels_pred = sorted(labels_pred[0])
        print(filename)
        if labels_pred:
            labels_pred = labels_pred[0]
#        print(filename)
#        print("labels_tapi", sorted(labels_tapioca))
#        print("labels_true", sorted(labels_true))
#        print("labels_pred", sorted(labels_pred))
        for k, v in options.items():
            data.append({"filename": filename, "qname": k, "entity": v, "annotation": k in labels_true, "prediction": k in labels_pred, "opentapioca": k in labels_tapioca, "correct_my": (k in labels_pred) == (k in labels_true), "correct_tapioca": (k in labels_tapioca) == (k in labels_true)})

@@ -54,40 +62,43 @@ def make_dataframe(annotations, opentapioca):
    print(df['correct_tapioca'].groupby(df['correct_tapioca']).count())
    return df

def main(argv):
def main():
    parser = argparse.ArgumentParser(description="Evaluate NEL annotations")
    parser.add_argument("database_file", help="Path to the SQLite database file")
    parser.add_argument("project_id", help="Project ID to evaluate")
    parser.add_argument("opentapioca_file", help="OpenTapioca predictions in JSON")
    parser.add_argument("opentapioca_files", help="Directory with OpenTapioca predictions in JSON")
    parser.add_argument("output_dirname", help="Directory to save the output")
    args = parser.parse_args()

    database_file = args.database_file
    project_id = args.project_id
    opentapioca_file = args.opentapioca_file
    opentapioca_files = args.opentapioca_files
    output_dirname = args.output_dirname

    if not os.path.exists(database_file):
        print(f"Database file {database_file} does not exist.")
        sys.exit(1)

    if not os.path.exists(opentapioca_file):
        print(f"OpenTapioca file {opentapioca_file} does not exist.")
    if not os.path.exists(opentapioca_files):
        print(f"Directory with OpenTapioca files {opentapioca_files} does not exist.")
        sys.exit(1)

    opentapioca = None
    with open(opentapioca_file, 'r') as f:
    opentapioca = {}
    for opentapioca_file in os.listdir(opentapioca_files):
        with open(os.path.join(opentapioca_files, opentapioca_file), 'r', encoding='utf-8') as f:
            opentapioca_json = json.load(f)
            filename = os.path.basename(opentapioca_file).split('.')[0]
            try:
            opentapioca = opentapioca_json.get('predictions')[0]["result"][-1].get('value')
                opentapioca[filename] = opentapioca_json.get('predictions')[0]["result"][-1].get('value')
            except:
            print("OpenTapioca JSON does not contain predictions or is not well formatted")
            sys.exit(1)
                pass

    if not opentapioca:
    if not opentapioca.keys():
        print("OpenTapioca predictions not found")
        sys.exit(1)

    print("OpenTapioca predictions", opentapioca)

    annotations = read_database(database_file, project_id)

    if not os.path.exists(output_dirname):
@@ -99,4 +110,4 @@ def main(argv):
        df.to_csv(f, sep='\t')

if __name__ == "__main__":
    main(sys.argv)
 No newline at end of file
    main()
 No newline at end of file