Commit 3e6d925b authored by Vít Starý Novotný's avatar Vít Starý Novotný
Browse files

Load predefined classifier

parent 701f8417
Loading
Loading
Loading
Loading
Loading
+17 −8
Original line number Diff line number Diff line
@@ -400,12 +400,21 @@ def remove_index_paragraphs(paragraphs: Iterable[List[Word]], classifier: Pipeli
            yield paragraph


def main(root_directory: Path, ground_truths: Iterable[Path], output_classifier_file: Path, input_vert_file: Path, num_input_lines: int, output_basename: Path, num_output_xlsx_rows: int, num_output_most_common_xlsx_rows: int) -> None:
def load_classifier(root_directory: Path, ground_truths: Iterable[Path], classifier_file: Path, pickle_protocol: int = 0) -> Pipeline:
    try:
        with classifier_file.open('rb') as f:
            classifier = pickle.load(f)
    except IOError:
        paragraph_classification = import_module('.07_paragraph_classification', package='scripts')
        train_classifier = paragraph_classification.train_classifier
        classifier = train_classifier(root_directory, ground_truths)
    with output_classifier_file.open('wb') as f:
        pickle.dump(classifier, f, protocol=0)
        with classifier_file.open('wb') as f:
            pickle.dump(classifier, f, protocol=pickle_protocol)
    return classifier


def main(root_directory: Path, ground_truths: Iterable[Path], classifier_file: Path, input_vert_file: Path, num_input_lines: int, output_basename: Path, num_output_xlsx_rows: int, num_output_most_common_xlsx_rows: int) -> None:
    classifier = load_classifier(root_directory, ground_truths, classifier_file)
    paragraphs = read_vert_file(input_vert_file, num_input_lines)
    paragraphs = (paragraph for book, page, paragraph in paragraphs)
    paragraphs = remove_index_paragraphs(paragraphs, classifier)
@@ -424,10 +433,10 @@ if __name__ == '__main__':
    assert len(sys.argv) == 9
    root_directory = Path(sys.argv[1])
    ground_truths = map(Path, json.loads(sys.argv[2]))
    output_classifier_file = Path(sys.argv[3])
    classifier_file = Path(sys.argv[3])
    input_vert_file = Path(sys.argv[4])
    num_input_lines = int(sys.argv[5])
    output_basename = Path(sys.argv[6])
    num_output_xlsx_rows = int(sys.argv[7])
    num_output_most_common_xlsx_rows = int(sys.argv[8])
    main(root_directory, ground_truths, output_classifier_file, input_vert_file, num_input_lines, output_basename, num_output_xlsx_rows, num_output_most_common_xlsx_rows)
    main(root_directory, ground_truths, classifier_file, input_vert_file, num_input_lines, output_basename, num_output_xlsx_rows, num_output_most_common_xlsx_rows)