From e6630c66843fb2a33fe6f7a2f40a0f1e86530958 Mon Sep 17 00:00:00 2001
From: Vit Novotny <witiko@mail.muni.cz>
Date: Tue, 27 Jul 2021 10:43:23 +0200
Subject: [PATCH] Add and evaluate layout detection

---
 requirements.txt                         |   1 +
 scripts/combine_tesseract_with_google.py | 121 +++++++++++++++++++++++
 scripts/common.py                        |  43 ++++++++
 3 files changed, 165 insertions(+)
 create mode 100644 scripts/combine_tesseract_with_google.py

diff --git a/requirements.txt b/requirements.txt
index 818df81c..46b7498c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,5 +7,6 @@ mysqlclient~=1.4.6
 numpy~=1.19.0
 opencv-python~=4.4.0.42
 pycountry~=20.7.3
+scikit-learn~=0.24.2
 scipy~=1.5.0
 tqdm~=4.46.1
diff --git a/scripts/combine_tesseract_with_google.py b/scripts/combine_tesseract_with_google.py
new file mode 100644
index 00000000..0234454e
--- /dev/null
+++ b/scripts/combine_tesseract_with_google.py
@@ -0,0 +1,121 @@
+# -*- coding:utf-8 -*-
+
+import logging
+
+
# Log-record layout shared by every logger in this script.
LOGGING_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'


# Deliberately configure logging BEFORE the imports below, so that any
# logging performed at import time by the imported modules already uses
# this format and level.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format=LOGGING_FORMAT)
+
+
+from multiprocessing import Pool
+from pathlib import Path
+import sys
+
+from tqdm import tqdm
+
+from .common import is_multicolumn, read_facts
+
+
# Command-line interface (positional arguments):
#   1. file with the upscaled high-confidence page facts (consumed by
#      read_facts(); exact record format defined in scripts/common.py)
#   2. root directory of the Tesseract OCR outputs (.hocr and .txt files)
#   3. root directory of the Google Vision AI OCR outputs (.txt files)
#   4. output root directory for the combined OCR texts
INPUT_UPSCALED_HIGH_CONFIDENCE_FILENAMES = sys.argv[1]
INPUT_OCR_ROOT_TESSERACT = Path(sys.argv[2])
INPUT_OCR_ROOT_GOOGLE = Path(sys.argv[3])
OUTPUT_OCR_ROOT = Path(sys.argv[4])
+
+
def read_texts_worker(args):
    """Read the OCR text for one page, choosing between the OCR engines.

    Tesseract output is used for multi-column pages (as detected from its
    hOCR layout), Google Vision AI output otherwise.

    Parameters
    ----------
    args : tuple
        A ``(downscaled_input_filename, input_basename)`` pair as produced
        by ``read_facts()``; only ``input_basename`` is used here.

    Returns
    -------
    tuple or str
        ``(input_basename, ocr_output, use_tesseract)`` on success, or one
        of the sentinel strings ``'no-hocr'``, ``'no-txt-tesseract'``, and
        ``'no-txt-google'`` naming the input file that was missing.
    """
    downscaled_input_filename, input_basename = args
    input_hocr_filename = (INPUT_OCR_ROOT_TESSERACT / input_basename).with_suffix('.hocr')
    try:
        use_tesseract = is_multicolumn(input_hocr_filename)
    except IOError:
        return 'no-hocr'
    if use_tesseract:
        input_filename = (INPUT_OCR_ROOT_TESSERACT / input_basename).with_suffix('.txt')
    else:
        input_filename = (INPUT_OCR_ROOT_GOOGLE / input_basename).with_suffix('.txt')
    try:
        # Path.read_text() closes the file even on error; the previous bare
        # open('rt').read() leaked the file handle.
        ocr_output = input_filename.read_text()
    except IOError:
        return 'no-txt-tesseract' if use_tesseract else 'no-txt-google'
    return (input_basename, ocr_output, use_tesseract)
+
+
def read_texts():
    """Read the combined OCR texts for all pages using a worker pool.

    Yields
    ------
    tuple
        ``(basename, ocr_output)`` pairs, one per successfully read page.

    Notes
    -----
    Summary statistics are logged only after the generator has been fully
    exhausted, since this is a generator function.
    """
    logger = logging.getLogger('read_texts')

    num_successful = 0
    num_no_hocr = 0
    num_no_txt_tesseract = 0
    num_no_txt_google = 0

    num_tesseract = 0
    num_google = 0

    # Materialize the facts so that tqdm can display a total.
    facts = list(read_facts(INPUT_UPSCALED_HIGH_CONFIDENCE_FILENAMES))
    with Pool(None) as pool:
        texts = pool.imap_unordered(read_texts_worker, facts)
        texts = tqdm(texts, desc='Reading OCR outputs', total=len(facts))
        for result in texts:
            # The worker returns sentinel strings for the failure modes.
            if result == 'no-hocr':
                num_no_hocr += 1
                continue
            if result == 'no-txt-tesseract':
                num_no_txt_tesseract += 1
                continue
            if result == 'no-txt-google':
                num_no_txt_google += 1
                continue
            basename, ocr_output, use_tesseract = result
            if use_tesseract:
                num_tesseract += 1
            else:
                num_google += 1
            yield (basename, ocr_output)
            num_successful += 1

    logger.info(
        'Read {} OCR texts, not found {} HOCR and {} TXT files ({} Tesseract + {} Google)'.format(
            num_successful,
            num_no_hocr,
            num_no_txt_tesseract + num_no_txt_google,
            num_no_txt_tesseract,
            num_no_txt_google,
        )
    )
    # Guard against ZeroDivisionError when not a single text could be read.
    if num_successful > 0:
        logger.info(
            'Out of the {} OCR texts, {} ({:.2f}%) were by Tesseract and {} ({:.2f}%) were by Google Vision AI'.format(
                num_successful,
                num_tesseract,
                100.0 * num_tesseract / num_successful,
                num_google,
                100.0 * num_google / num_successful,
            )
        )
+
+
def write_texts_worker(args):
    """Write one OCR text below OUTPUT_OCR_ROOT, creating parent dirs.

    *args* is an ``(output_basename, ocr_output)`` pair as yielded by
    ``read_texts()``.
    """
    output_basename, ocr_output = args
    output_filename = (OUTPUT_OCR_ROOT / output_basename).with_suffix('.txt')
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    # Terminate the text with a newline, matching print()'s behaviour.
    output_filename.write_text('{}\n'.format(ocr_output))
+
+
def write_texts(texts):
    """Write all *texts* (``(basename, ocr_output)`` pairs) in parallel."""
    OUTPUT_OCR_ROOT.mkdir(exist_ok=True)
    with Pool(None) as pool:
        results = pool.imap_unordered(write_texts_worker, texts)
        # imap_unordered is lazy; drain it so every write actually happens.
        for _ in results:
            pass
+
+
def combine():
    """Combine the Tesseract and Google OCR outputs and write them out."""
    write_texts(read_texts())
+
+
# Script entry point: run the combination end-to-end.
if __name__ == '__main__':
    combine()
diff --git a/scripts/common.py b/scripts/common.py
index 06c18339..8af2ad66 100644
--- a/scripts/common.py
+++ b/scripts/common.py
@@ -21,11 +21,15 @@ from lxml import etree
 import numpy as np
 import pycountry
 import scipy.stats as st
+from sklearn.svm import OneClassSVM
+from sklearn.cluster import KMeans
+from sklearn.metrics import silhouette_score
 
 from .configuration import CSV_PARAMETERS, SQL_PARAMETERS, JSON_PARAMETERS, PREPROCESSED_IMAGE_WIDTH, PREPROCESSED_IMAGE_HEIGHT, ANNOY_N_TREES
 
 
 DATABASE = None
+BBOX_REGEX = re.compile(r'(^|;)\s*bbox\s+(?P<x0>[0-9]+)\s+(?P<y0>[0-9]+)\s+(?P<x1>[0-9]+)\s+(?P<y1>[0-9]+)\s*($|;)')
 URL_REGEX = re.compile(r'https://sources.cms.flu.cas.cz/.*&bookid=(?P<book_id>[0-9]+).*')
 IMAGE_FILENAME_REGEX = re.compile(r'\./(?P<book_id>[0-9]+)(_delete)?/(?P<page>[0-9]+)\.(jpg|png|tif)')
 RELEVANT_PAGE_ANNOTATION_REGEX = re.compile(r'\* *(?P<pages>([0-9]+(-[0-9]+)?(; *)?)+)')
@@ -602,6 +606,45 @@ def print_confidence_interval(file, sample, name=None, unit=None, confidence=95.
     print(file=file)
 
 
def is_multicolumn(filename):
    """Return True if the hOCR file appears to contain several text columns.

    ``get_number_of_columns()`` clusters the left AND right x-boundaries of
    the OCR lines, so a single clean column yields two clusters; more than
    two clusters therefore suggests more than one column.
    """
    number_of_boundary_clusters = get_number_of_columns(filename)
    return number_of_boundary_clusters > 2
+
+
def get_number_of_columns(filename, ks=range(2, 10)):
    """Estimate the number of text-boundary clusters of a page from hOCR.

    Collects the left and right x-coordinates of every OCR line, removes
    outliers with a one-class SVM, and picks the K-means clustering with
    the best silhouette score.  Each text column contributes two
    boundaries (its left and right edges), so a clean single-column page
    typically scores 2.

    Parameters
    ----------
    filename : pathlib.Path
        Path of an hOCR file.
    ks : iterable of int, optional
        Candidate cluster counts, assumed sorted ascending (the loop
        breaks at the first infeasible k).

    Returns
    -------
    int
        The best-scoring number of clusters, or 1 when there are too few
        distinct boundaries to cluster (including a page with no lines).

    Raises
    ------
    ValueError
        If an ``ocr_line`` span carries a ``title`` without a bbox.
    """
    with filename.open('rb') as f:
        html5_parser = etree.HTMLParser(huge_tree=True)
        xml_document = etree.parse(f, html5_parser)

    left_boundaries, right_boundaries = [], []
    for line in xml_document.xpath('//span[@class="ocr_line" and @title]'):
        # Use search() rather than match(): the hOCR format does not
        # guarantee that `bbox` is the first property of the `title`
        # attribute, and match() anchors the `(^|;)` alternation so the
        # `;` branch could never fire.
        match = BBOX_REGEX.search(line.attrib['title'])
        if match is None:
            # Raise instead of assert: asserts vanish under `python -O`.
            raise ValueError('No bbox in hOCR title {!r}'.format(line.attrib['title']))
        left_boundaries.append(float(match.group('x0')))
        right_boundaries.append(float(match.group('x1')))

    boundaries = left_boundaries + right_boundaries
    if not boundaries:
        # A page with no recognized lines cannot be clustered; previously
        # this crashed inside OneClassSVM on an empty array.
        return 1
    are_outliers = OneClassSVM().fit_predict(np.array(boundaries).reshape(-1, 1))
    # OneClassSVM labels inliers 1 and outliers -1; keep only the inliers.
    boundaries = [
        boundary
        for boundary, is_outlier
        in zip(boundaries, are_outliers)
        if is_outlier == 1
    ]
    num_unique_boundaries = len(set(boundaries))
    X = np.array(boundaries).reshape(-1, 1)
    best_k, best_silhouette = 1, float('-inf')
    for k in ks:
        # silhouette_score needs at least k+1 distinct samples.
        if k >= num_unique_boundaries:
            break
        y = KMeans(n_clusters=k).fit_predict(X)
        silhouette = silhouette_score(X, y)
        if silhouette > best_silhouette:
            best_k, best_silhouette = k, silhouette

    return best_k
+
+
 def normalize_language_code(language_code):
     try:
         language = pycountry.languages.lookup(language_code)
-- 
GitLab