# -*- coding:utf-8 -*-

"""Combine Tesseract and Google Vision AI OCR outputs into one corpus.

For every page listed in the high-confidence filename list, pick the
Tesseract plain-text output when the page's hOCR layout looks
multi-column, and the Google Vision AI output otherwise, then write the
chosen text under a single output tree.

Usage:
    python -m scripts.combine_tesseract_with_google \
        HIGH_CONFIDENCE_FILENAMES TESSERACT_ROOT GOOGLE_ROOT OUTPUT_ROOT
"""

import logging


LOGGING_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT)


from multiprocessing import Pool
from pathlib import Path
import sys

from tqdm import tqdm

from .common import is_multicolumn, read_facts


# NOTE(review): reading sys.argv at import time makes this module crash when
# imported without CLI arguments; kept because the pool worker functions need
# these as module-level globals so forked workers inherit them.
INPUT_UPSCALED_HIGH_CONFIDENCE_FILENAMES = sys.argv[1]
INPUT_OCR_ROOT_TESSERACT = Path(sys.argv[2])
INPUT_OCR_ROOT_GOOGLE = Path(sys.argv[3])
OUTPUT_OCR_ROOT = Path(sys.argv[4])


def read_texts_worker(args):
    """Read one page's OCR text, preferring Tesseract for multi-column pages.

    Parameters
    ----------
    args : tuple
        ``(downscaled_input_filename, input_basename)`` as produced by
        ``read_facts``; only the basename is used here.

    Returns
    -------
    tuple or str
        ``(input_basename, ocr_output, use_tesseract)`` on success, or one of
        the sentinel strings ``'no-hocr'``, ``'no-txt-tesseract'``,
        ``'no-txt-google'`` when an input file is missing.
    """
    downscaled_input_filename, input_basename = args
    input_hocr_filename = (INPUT_OCR_ROOT_TESSERACT / input_basename).with_suffix('.hocr')
    try:
        use_tesseract = is_multicolumn(input_hocr_filename)
    except IOError:
        return 'no-hocr'
    if use_tesseract:
        input_filename = (INPUT_OCR_ROOT_TESSERACT / input_basename).with_suffix('.txt')
    else:
        input_filename = (INPUT_OCR_ROOT_GOOGLE / input_basename).with_suffix('.txt')
    try:
        # FIX: the original leaked the file handle via ``open('rt').read()``;
        # use a context manager so the descriptor is closed deterministically.
        with input_filename.open('rt') as f:
            ocr_output = f.read()
    except IOError:
        return 'no-txt-tesseract' if use_tesseract else 'no-txt-google'
    return (input_basename, ocr_output, use_tesseract)


def read_texts():
    """Yield ``(basename, ocr_output)`` pairs for all readable pages.

    Fans the per-page work out over a process pool and logs summary
    statistics (missing files, Tesseract/Google split) when exhausted.
    """
    logger = logging.getLogger('read_texts')

    num_successful = 0
    num_no_hocr = 0
    num_no_txt_tesseract = 0
    num_no_txt_google = 0

    num_tesseract = 0
    num_google = 0

    facts = list(read_facts(INPUT_UPSCALED_HIGH_CONFIDENCE_FILENAMES))
    with Pool(None) as pool:
        texts = pool.imap_unordered(read_texts_worker, facts)
        texts = tqdm(texts, desc='Reading OCR outputs', total=len(facts))
        for result in texts:
            if result == 'no-hocr':
                num_no_hocr += 1
                continue
            if result == 'no-txt-tesseract':
                num_no_txt_tesseract += 1
                continue
            if result == 'no-txt-google':
                num_no_txt_google += 1
                continue
            basename, ocr_output, use_tesseract = result
            if use_tesseract:
                num_tesseract += 1
            else:
                num_google += 1
            yield (basename, ocr_output)
            num_successful += 1

    logger.info(
        'Read {} OCR texts, not found {} HOCR and {} TXT files ({} Tesseract + {} Google)'.format(
            num_successful,
            num_no_hocr,
            num_no_txt_tesseract + num_no_txt_google,
            num_no_txt_tesseract,
            num_no_txt_google,
        )
    )
    # FIX: guard the percentage report — the original divided by
    # num_successful unconditionally and raised ZeroDivisionError when no
    # OCR text could be read at all.
    if num_successful:
        logger.info(
            'Out of the {} OCR texts, {} ({:.2f}%) were by Tesseract and {} ({:.2f}%) were by Google Vision AI'.format(
                num_successful,
                num_tesseract,
                100.0 * num_tesseract / num_successful,
                num_google,
                100.0 * num_google / num_successful,
            )
        )


def write_texts_worker(args):
    """Write one page's OCR text below OUTPUT_OCR_ROOT, creating parent dirs."""
    output_basename, ocr_output = args
    output_filename = (OUTPUT_OCR_ROOT / output_basename).with_suffix('.txt')
    output_dirname = output_filename.parent
    output_dirname.mkdir(parents=True, exist_ok=True)
    with output_filename.open('wt') as f:
        print(ocr_output, file=f)


def write_texts(texts):
    """Drain ``texts`` through a process pool, writing each page to disk."""
    OUTPUT_OCR_ROOT.mkdir(exist_ok=True)
    with Pool(None) as pool:
        for _ in pool.imap_unordered(write_texts_worker, texts):
            pass


def combine():
    """Read the combined Tesseract/Google texts and write them out."""
    texts = read_texts()
    write_texts(texts)


if __name__ == '__main__':
    combine()
# Third-party scientific stack used throughout this module.
from lxml import etree
import numpy as np
import pycountry
import scipy.stats as st
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.svm import OneClassSVM

from .configuration import CSV_PARAMETERS, SQL_PARAMETERS, JSON_PARAMETERS, PREPROCESSED_IMAGE_WIDTH, PREPROCESSED_IMAGE_HEIGHT, ANNOY_N_TREES


# Module-level database handle; presumably initialised lazily elsewhere in
# this module — not visible here, confirm before relying on it.
DATABASE = None
# Extracts the "bbox x0 y0 x1 y1" property from an hOCR ``title`` attribute.
BBOX_REGEX = re.compile(r'(^|;)\s*bbox\s+(?P<x0>[0-9]+)\s+(?P<y0>[0-9]+)\s+(?P<x1>[0-9]+)\s+(?P<y1>[0-9]+)\s*($|;)')
# Recognises source-database URLs and captures the numeric book identifier.
URL_REGEX = re.compile(r'https://sources.cms.flu.cas.cz/.*&bookid=(?P<book_id>[0-9]+).*')
# Relative scan-image paths: ./<book_id>[_delete]/<page>.<ext>
IMAGE_FILENAME_REGEX = re.compile(r'\./(?P<book_id>[0-9]+)(_delete)?/(?P<page>[0-9]+)\.(jpg|png|tif)')
# Page-range annotations such as "* 12-15; 20".
RELEVANT_PAGE_ANNOTATION_REGEX = re.compile(r'\* *(?P<pages>([0-9]+(-[0-9]+)?(; *)?)+)')
def is_multicolumn(filename):
    """Return True when the hOCR page at *filename* appears multi-column.

    A single column of text contributes two boundary clusters (the shared
    left edge and the shared right edge of its lines), so strictly more
    than two clusters implies more than one column.
    """
    return get_number_of_columns(filename) > 2


def get_number_of_columns(filename, ks=range(2, 10)):
    """Estimate the number of line-boundary clusters in an hOCR page.

    Parses the hOCR file, collects the left and right x-coordinates of every
    ``ocr_line`` span, drops outlying boundaries with a one-class SVM, and
    picks the KMeans cluster count ``k`` (from *ks*) with the best
    silhouette score.

    Parameters
    ----------
    filename : pathlib.Path
        Path to an hOCR document.
    ks : iterable of int
        Candidate cluster counts to evaluate, in increasing order.

    Returns
    -------
    int
        The best cluster count, or 1 when too few distinct boundaries exist
        (including pages with no recognised lines at all).
    """
    with filename.open('rb') as f:
        html5_parser = etree.HTMLParser(huge_tree=True)
        xml_document = etree.parse(f, html5_parser)

    left_boundaries, right_boundaries = [], []
    for line in xml_document.xpath('//span[@class="ocr_line" and @title]'):
        match = BBOX_REGEX.match(line.attrib['title'])
        assert match is not None, line.attrib['title']
        left_boundary, right_boundary = float(match.group('x0')), float(match.group('x1'))
        left_boundaries.append(left_boundary)
        right_boundaries.append(right_boundary)

    boundaries = left_boundaries + right_boundaries
    # FIX: a page with no ocr_line spans previously crashed OneClassSVM with
    # a ValueError (empty input), killing the whole multiprocessing run;
    # treat such pages as single-column instead.
    if not boundaries:
        return 1

    # Discard outlier boundaries (e.g. marginalia, page numbers) before
    # clustering; OneClassSVM labels inliers with +1.
    are_outliers = OneClassSVM().fit_predict(np.array(boundaries).reshape(-1, 1))
    boundaries = [
        boundary
        for boundary, is_outlier
        in zip(boundaries, are_outliers)
        if is_outlier == 1
    ]
    num_unique_boundaries = len(set(boundaries))
    X = np.array(boundaries).reshape(-1, 1)
    best_k, best_silhouette = 1, float('-inf')
    for k in ks:
        # KMeans needs strictly more unique samples than clusters.
        if k >= num_unique_boundaries:
            break
        y = KMeans(n_clusters=k).fit_predict(X)
        silhouette = silhouette_score(X, y)
        if silhouette > best_silhouette:
            best_k, best_silhouette = k, silhouette

    return best_k