diff --git a/src/eynollah/cli/__init__.py b/src/eynollah/cli/__init__.py index 43ed046..1584fa5 100644 --- a/src/eynollah/cli/__init__.py +++ b/src/eynollah/cli/__init__.py @@ -1,7 +1,3 @@ -# NOTE: For predictable order of imports of torch/shapely/tensorflow -# this must be the first import of the CLI! -from ..eynollah_imports import imported_libs - from .cli import main from .cli_binarize import binarize_cli from .cli_enhance import enhance_cli diff --git a/src/eynollah/cli/cli.py b/src/eynollah/cli/cli.py index ace3f1c..2a4c8d1 100644 --- a/src/eynollah/cli/cli.py +++ b/src/eynollah/cli/cli.py @@ -15,6 +15,7 @@ class EynollahCliCtx: Holds options relevant for all eynollah subcommands """ model_zoo: EynollahModelZoo + device: str = '' log_level : Union[str, None] = 'INFO' @@ -35,6 +36,11 @@ class EynollahCliCtx: type=(str, str, str), multiple=True, ) +@click.option( + "--device", + "-D", + help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", +) @click.option( "--log_level", "-l", @@ -42,7 +48,7 @@ class EynollahCliCtx: help="Override log level globally to this", ) @click.pass_context -def main(ctx, model_basedir, model_overrides, log_level): +def main(ctx, model_basedir, model_overrides, device, log_level): """ eynollah - Document Layout Analysis, Image Enhancement, OCR """ @@ -58,6 +64,7 @@ def main(ctx, model_basedir, model_overrides, log_level): # Initialize CLI context ctx.obj = EynollahCliCtx( model_zoo=model_zoo, + device=device, log_level=log_level, ) diff --git a/src/eynollah/cli/cli_binarize.py b/src/eynollah/cli/cli_binarize.py index f0e56f5..82209be 100644 --- a/src/eynollah/cli/cli_binarize.py +++ b/src/eynollah/cli/cli_binarize.py @@ -1,6 +1,8 @@ import click -@click.command() +@click.command(context_settings=dict( + help_option_names=['-h', '--help'], + show_default=True)) @click.option( '--patches/--no-patches', default=True, @@ -31,11 +33,6 @@ import click help="overwrite (instead of skipping) if output xml exists", is_flag=True, ) -@click.option( - "--device", - "-D", - help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", -) @click.pass_context def binarize_cli( ctx, @@ -44,14 +41,14 @@ def binarize_cli( dir_in, output, overwrite, - device, ): """ Binarize images with a ML model """ from ..sbb_binarize import SbbBinarizer assert bool(input_image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, device=device) + binarizer = SbbBinarizer(model_zoo=ctx.obj.model_zoo, + device=ctx.obj.device) binarizer.run( image_filename=input_image, use_patches=patches, diff --git a/src/eynollah/cli/cli_enhance.py b/src/eynollah/cli/cli_enhance.py index 517e1e8..bcb8263 100644 --- a/src/eynollah/cli/cli_enhance.py +++ b/src/eynollah/cli/cli_enhance.py @@ -1,6 +1,8 @@ import click -@click.command() +@click.command(context_settings=dict( + help_option_names=['-h', '--help'], + show_default=True)) @click.option( "--image", "-i", @@ -46,13 +48,8 @@ import click is_flag=True, help="save the enhanced image in original image size", ) -@click.option( - "--device", - "-D", - help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", -) @click.pass_context -def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale, device): +def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower, save_org_scale): """ Enhance image """ @@ -60,10 +57,10 @@ def enhance_cli(ctx, image, out, overwrite, dir_in, num_col_upper, num_col_lower from ..image_enhancer import Enhancer enhancer = Enhancer( model_zoo=ctx.obj.model_zoo, + device=ctx.obj.device, num_col_upper=num_col_upper, num_col_lower=num_col_lower, save_org_scale=save_org_scale, - device=device, ) enhancer.run(overwrite=overwrite, dir_in=dir_in, diff --git a/src/eynollah/cli/cli_extract_images.py b/src/eynollah/cli/cli_extract_images.py index 0add5b5..acd31f1 100644 --- a/src/eynollah/cli/cli_extract_images.py +++ b/src/eynollah/cli/cli_extract_images.py @@ -1,6 +1,8 @@ import click -@click.command() +@click.command(context_settings=dict( + help_option_names=['-h', '--help'], + show_default=True)) @click.option( "--image", "-i", @@ -30,36 +32,40 @@ import click @click.option( "--save_images", "-si", - help="if a directory is given, images in documents will be cropped and saved there", + help="if a directory is given, cropped images of pages will be saved there", type=click.Path(exists=True, file_okay=False), ) @click.option( - "--enable-plotting/--disable-plotting", - "-ep/-noep", + "--enable-plotting", + "-ep", is_flag=True, - help="If set, will plot intermediary files and images", + help="plot intermediary diagnostic images to files", ) @click.option( - "--input_binary/--input-RGB", - "-ib/-irgb", + "--input_binary", + "-ib", is_flag=True, - help="In general, eynollah uses RGB as input but if the input document is very dark, very bright or for any other reason you can turn on input binarization. When this flag is set, eynollah will binarize the RGB input document, you should always provide RGB images to eynollah.", + help="In general, eynollah uses RGB as input, but if the input document is very dark, very bright or for any other reason you can turn on internal binarization here. When set, eynollah will binarize the RGB input document first.", ) @click.option( - "--ignore_page_extraction/--extract_page_included", - "-ipe/-epi", + "--ignore_page_extraction", + "-ipe", is_flag=True, - help="if this parameter set to true, this tool would ignore page extraction", + help="ignore page extraction (cropping via page frame detection model)", ) @click.option( "--num_col_upper", "-ncu", - help="lower limit of columns in document image", + default=0, + type=click.IntRange(min=0), + help="lower limit of columns in document image; 0 means autodetected from model", ) @click.option( "--num_col_lower", "-ncl", - help="upper limit of columns in document image", + default=0, + type=click.IntRange(min=0), + help="upper limit of columns in document image; 0 means autodetected from model", ) @click.pass_context def extract_images_cli( diff --git a/src/eynollah/cli/cli_layout.py b/src/eynollah/cli/cli_layout.py index 417b202..0a083d5 100644 --- a/src/eynollah/cli/cli_layout.py +++ b/src/eynollah/cli/cli_layout.py @@ -172,11 +172,6 @@ import click type=click.FloatRange(min=0), help="abort when number of failed images exceeds this value (if >=1) or ratio of failed over total images exceeds this value (if <1); 0 means ignore failures", ) -@click.option( - "--device", - "-D", - help="placement of computations in predictors for each model type; if none (by default), will try to use first available GPU or fall back to CPU; set string to force using a device (e.g. 'GPU0', 'GPU1' or 'CPU'). Can also be a comma-separated list of model category to device mappings (e.g. 'col_classifier:CPU,page:GPU0,*:GPU1')", -) @click.pass_context def layout_cli( ctx, @@ -207,7 +202,6 @@ def layout_cli( ignore_page_extraction, num_jobs, halt_fail, - device, ): """ Detect Layout (with optional image enhancement and reading order detection) @@ -223,7 +217,7 @@ def layout_cli( assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." eynollah = Eynollah( model_zoo=ctx.obj.model_zoo, - device=device, + device=ctx.obj.device, enable_plotting=enable_plotting, allow_enhancement=allow_enhancement, curved_line=curved_line, diff --git a/src/eynollah/cli/cli_ocr.py b/src/eynollah/cli/cli_ocr.py index 406af61..daeccbe 100644 --- a/src/eynollah/cli/cli_ocr.py +++ b/src/eynollah/cli/cli_ocr.py @@ -1,6 +1,8 @@ import click -@click.command() +@click.command(context_settings=dict( + help_option_names=['-h', '--help'], + show_default=True)) @click.option( "--image", "-i", @@ -16,7 +18,7 @@ import click @click.option( "--dir_in_bin", "-dib", - help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png' \n Perform prediction using both RGB and binary images. (This does not necessarily improve results, however it may be beneficial for certain document images."), + help=("directory of binarized images (in addition to --dir_in for RGB images; filename stems must match the RGB image files, with '.png'. \n Perform prediction using both RGB and binary images. (This may improve results for certain document images.)"), type=click.Path(exists=True, file_okay=False), ) @click.option( @@ -47,25 +49,29 @@ import click ) @click.option( "--tr_ocr", - "-trocr/-notrocr", + "-trocr", is_flag=True, - help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.", + help="use transformer OCR (instead of classic CNN-RNN) model", ) @click.option( "--do_not_mask_with_textline_contour", - "-nmtc/-mtc", + "-nmtc", is_flag=True, - help="if this parameter set to true, cropped textline images will not be masked with textline contour.", + help="skip masking each cropped textline image with its corresponding textline contour", ) @click.option( "--batch_size", "-bs", + default=0, + type=click.IntRange(min=0), help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively", ) @click.option( "--min_conf_value_of_textline_text", "-min_conf", - help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.", + default=0.3, + type=click.FloatRange(min=0.0, max=1.0), + help="minimum OCR confidence threshold. Text lines with a lower confidence value will not be included in the output XML file.", ) @click.pass_context def ocr_cli( @@ -85,14 +91,16 @@ def ocr_cli( """ Recognize text with a CNN/RNN or transformer ML model. """ - assert bool(image) ^ bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." + assert bool(image) != bool(dir_in), "Either -i (single image) or -di (directory) must be provided, but not both." from ..eynollah_ocr import Eynollah_ocr eynollah_ocr = Eynollah_ocr( model_zoo=ctx.obj.model_zoo, + device=ctx.obj.device, tr_ocr=tr_ocr, do_not_mask_with_textline_contour=do_not_mask_with_textline_contour, batch_size=batch_size, - min_conf_value_of_textline_text=min_conf_value_of_textline_text) + min_conf_value_of_textline_text=min_conf_value_of_textline_text, + ) eynollah_ocr.run(overwrite=overwrite, dir_in=dir_in, dir_in_bin=dir_in_bin, diff --git a/src/eynollah/cli/cli_readingorder.py b/src/eynollah/cli/cli_readingorder.py index 0f44b7f..ac52e38 100644 --- a/src/eynollah/cli/cli_readingorder.py +++ b/src/eynollah/cli/cli_readingorder.py @@ -1,6 +1,8 @@ import click -@click.command() +@click.command(context_settings=dict( + help_option_names=['-h', '--help'], + show_default=True)) @click.option( "--input", "-i", @@ -25,9 +27,10 @@ def readingorder_cli(ctx, input, dir_in, out): """ Generate ReadingOrder with a ML model """ - from ..mb_ro_on_layout import machine_based_reading_order_on_layout + from ..mb_ro_on_layout import Reorder assert bool(input) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." - orderer = machine_based_reading_order_on_layout(model_zoo=ctx.obj.model_zoo) + orderer = Reorder(model_zoo=ctx.obj.model_zoo, + device=ctx.obj.device) orderer.run(xml_filename=input, dir_in=dir_in, dir_out=out, diff --git a/src/eynollah/extract_images.py b/src/eynollah/extract_images.py index 7a7e3f6..40476a3 100644 --- a/src/eynollah/extract_images.py +++ b/src/eynollah/extract_images.py @@ -9,7 +9,6 @@ import os import time from typing import Optional from pathlib import Path -import tensorflow as tf import numpy as np import cv2 @@ -64,12 +63,6 @@ class EynollahImageExtractor(Eynollah): t_start = time.time() - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - except: - self.logger.warning("no GPU device available") - self.logger.info("Loading models...") self.setup_models() self.logger.info(f"Model initialization complete ({time.time() - t_start:.1f}s)") diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index c632941..9db47ce 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -1148,7 +1148,6 @@ class Eynollah: boxes, textline_mask_tot ): - assert np.any(textline_mask_tot) self.logger.debug("enter do_order_of_regions") contours_only_text_parent = ensure_array(contours_only_text_parent) contours_only_text_parent_h = ensure_array(contours_only_text_parent_h) diff --git a/src/eynollah/eynollah_imports.py b/src/eynollah/eynollah_imports.py deleted file mode 100644 index 496406c..0000000 --- a/src/eynollah/eynollah_imports.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Load libraries with possible race conditions once. This must be imported as the first module of eynollah. -""" -import os -os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 - -from ocrd_utils import tf_disable_interactive_logs -from torch import * -tf_disable_interactive_logs() -import tensorflow.keras -from shapely import * -imported_libs = True -__all__ = ['imported_libs'] diff --git a/src/eynollah/eynollah_ocr.py b/src/eynollah/eynollah_ocr.py index 1b49077..b94853b 100644 --- a/src/eynollah/eynollah_ocr.py +++ b/src/eynollah/eynollah_ocr.py @@ -14,16 +14,14 @@ from cv2.typing import MatLike from xml.etree import ElementTree as ET from PIL import Image, ImageDraw import numpy as np -from eynollah.model_zoo import EynollahModelZoo -from eynollah.utils.font import get_font -from eynollah.utils.xml import etree_namespace_for_element_tag -try: - import torch -except ImportError: - torch = None +from ocrd_utils import polygon_from_points, xywh_from_polygon +from .eynollah import Eynollah +from .model_zoo import EynollahModelZoo from .utils import is_image_filename +from .utils.font import get_font +from .utils.xml import etree_namespace_for_element_tag from .utils.resize import resize_image from .utils.utils_ocr import ( break_curved_line_into_small_pieces_and_then_merge, @@ -34,6 +32,7 @@ from .utils.utils_ocr import ( preprocess_and_resize_image_for_ocrcnn_model, return_textlines_split_if_needed, rotate_image_with_padding, + batched, ) # TODO: refine typing @@ -44,45 +43,44 @@ class EynollahOcrResult: cropped_lines_region_indexer: List total_bb_coordinates:List -class Eynollah_ocr: +class Eynollah_ocr(Eynollah): def __init__( self, *, model_zoo: EynollahModelZoo, tr_ocr=False, - batch_size: Optional[int]=None, + batch_size: int=0, do_not_mask_with_textline_contour: bool=False, - min_conf_value_of_textline_text : Optional[float]=None, + min_conf_value_of_textline_text : float=0.3, logger: Optional[Logger]=None, + device: str = '', ): self.tr_ocr = tr_ocr # masking for OCR and GT generation, relevant for skewed lines and bounding boxes self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour self.logger = logger if logger else getLogger('eynollah.ocr') - self.model_zoo = model_zoo - self.min_conf_value_of_textline_text = min_conf_value_of_textline_text if min_conf_value_of_textline_text else 0.3 - self.b_s = 2 if batch_size is None and tr_ocr else 8 if batch_size is None else batch_size + self.min_conf_value_of_textline_text = min_conf_value_of_textline_text + self.b_s = batch_size or 2 if tr_ocr else 8 - if tr_ocr: - self.model_zoo.load_models('trocr_processor') - self.model_zoo.load_models(['ocr', 'tr']) - self.model_zoo.get('ocr').to(self.device) + self.model_zoo = model_zoo + self.setup_models(device=device) + + def setup_models(self, device=''): + if self.tr_ocr: + self.model_zoo.load_models('trocr_processor', + ('ocr', 'tr'), + device=device) else: - self.model_zoo.load_models('ocr') - self.model_zoo.load_models('num_to_char') - self.model_zoo.load_models('characters') + self.model_zoo.load_models('ocr', + 'num_to_char', + 'characters', + device=device) self.end_character = len(self.model_zoo.get('characters')) + 2 @property def device(self): - assert torch - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - return torch.device("cuda:0") - else: - self.logger.info("Using CPU processing") - return torch.device("cpu") + return self.model_zoo.get('ocr').device def run_trocr( self, @@ -94,174 +92,94 @@ class Eynollah_ocr: ) -> EynollahOcrResult: total_bb_coordinates = [] - - cropped_lines = [] cropped_lines_region_indexer = [] cropped_lines_meging_indexing = [] - extracted_texts = [] + extracted_confs = [] - indexer_text_region = 0 - indexer_b_s = 0 - - for nn in page_tree.getroot().iter(f'{{{page_ns}}}TextRegion'): - for child_textregion in nn: - if child_textregion.tag.endswith("TextLine"): - - for child_textlines in child_textregion: - if child_textlines.tag.endswith("Coords"): - cropped_lines_region_indexer.append(indexer_text_region) - p_h=child_textlines.attrib['points'].split(' ') - textline_coords = np.array( [ [int(x.split(',')[0]), - int(x.split(',')[1]) ] - for x in p_h] ) - x,y,w,h = cv2.boundingRect(textline_coords) - - total_bb_coordinates.append([x,y,w,h]) - - h2w_ratio = h/float(w) - - img_poly_on_img = np.copy(img) - mask_poly = np.zeros(img.shape) - mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) - - mask_poly = mask_poly[y:y+h, x:x+w, :] - img_crop = img_poly_on_img[y:y+h, x:x+w, :] - img_crop[mask_poly==0] = 255 - - self.logger.debug("processing %d lines for '%s'", - len(cropped_lines), nn.attrib['id']) - if h2w_ratio > 0.1: - cropped_lines.append(resize_image(img_crop, - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width) ) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - splited_images, _ = return_textlines_split_if_needed(img_crop, None) - #print(splited_images) - if splited_images: - cropped_lines.append(resize_image(splited_images[0], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - cropped_lines.append(resize_image(splited_images[1], - tr_ocr_input_height_and_width, - tr_ocr_input_height_and_width)) - cropped_lines_meging_indexing.append(-1) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - else: - cropped_lines.append(img_crop) - cropped_lines_meging_indexing.append(0) - indexer_b_s+=1 - - if indexer_b_s==self.b_s: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate( - pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - - - indexer_text_region = indexer_text_region +1 + for n_region, region in enumerate(page_tree.getroot().iter('{%s}TextRegion' % page_ns)): + for n_line, line in enumerate(region.iter('{%s}TextLine' % page_ns)): + cropped_lines_region_indexer.append(n_region) - if indexer_b_s!=0: - imgs = cropped_lines[:] - cropped_lines = [] - indexer_b_s = 0 - - pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - generated_ids_merged = self.model_zoo.get('ocr').generate(pixel_values_merged.to(self.device)) - generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode(generated_ids_merged, skip_special_tokens=True) - - extracted_texts = extracted_texts + generated_text_merged - - ####extracted_texts = [] - ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + coords = line.find('{%s}Coords' % page_ns) + if coords is None: + self.logger.warning("region '%s' line '%s' has no Coords", region.attrib['id'], line.attrib['id']) + continue + poly = np.array(polygon_from_points(coords.attrib['points'])).astype(int) + cont = poly[:, np.newaxis] + xywh = xywh_from_polygon(poly) + x, y, w, h = xywh['x'], xywh['y'], xywh['w'], xywh['h'] - ####for i in range(n_iterations): - ####if i==(n_iterations-1): - ####n_start = i*self.b_s - ####imgs = cropped_lines[n_start:] - ####else: - ####n_start = i*self.b_s - ####n_end = (i+1)*self.b_s - ####imgs = cropped_lines[n_start:n_end] - ####pixel_values_merged = self.model_zoo.get('trocr_processor')(imgs, return_tensors="pt").pixel_values - ####generated_ids_merged = self.model_ocr.generate( - #### pixel_values_merged.to(self.device)) - ####generated_text_merged = self.model_zoo.get('trocr_processor').batch_decode( - #### generated_ids_merged, skip_special_tokens=True) - - ####extracted_texts = extracted_texts + generated_text_merged - + total_bb_coordinates.append([x, y, w, h]) + + img_crop = img[y: y + h, x: x + w] + if not self.do_not_mask_with_textline_contour: + mask_poly = np.zeros(img_crop.shape[:2], dtype=np.uint8) + mask_poly = cv2.fillPoly(mask_poly, pts=[cont - [x, y]], color=1) + img_crop[mask_poly == 0] = 255 # FIXME: or median color? + + if h > 0.1 * w: + cropped_lines.append(resize_image(img_crop, + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width) ) + cropped_lines_meging_indexing.append(0) + else: + splited_images, _ = return_textlines_split_if_needed(img_crop, None) + if splited_images: + cropped_lines.append(resize_image(splited_images[0], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(1) + cropped_lines.append(resize_image(splited_images[1], + tr_ocr_input_height_and_width, + tr_ocr_input_height_and_width)) + cropped_lines_meging_indexing.append(-1) + else: + cropped_lines.append(img_crop) + cropped_lines_meging_indexing.append(0) + + + self.logger.debug("processing %d lines for %d regions", + len(cropped_lines), len(set(cropped_lines_region_indexer))) + for imgs in batched(cropped_lines, self.b_s): + pixel_values = self.model_zoo.get('trocr_processor')( + imgs, return_tensors="pt").pixel_values + output = self.model_zoo.get('ocr').generate( + pixel_values.to(self.device), + # beam search instead of greedy decoding: + num_beams=4, + # also return probability + output_scores=True, + return_dict_in_generate=True) + if output.sequences_scores is not None: + # log-prob averaged over length + conf = output.sequences_scores.exp().clamp(0.0, 1.0).tolist() + else: + conf = [1.0] * len(output.sequences) + text = self.model_zoo.get('trocr_processor').batch_decode( + output.sequences, + skip_special_tokens=True, + clean_up_tokenization_spaces=False) + extracted_confs.extend(conf) + extracted_texts.extend(text) del cropped_lines gc.collect() extracted_texts_merged = [extracted_texts[ind] - if cropped_lines_meging_indexing[ind]==0 - else extracted_texts[ind]+" "+extracted_texts[ind+1] - if cropped_lines_meging_indexing[ind]==1 - else None - for ind in range(len(cropped_lines_meging_indexing))] - - extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] - #print(extracted_texts_merged, len(extracted_texts_merged)) + if cropped_lines_meging_indexing[ind] == 0 + else extracted_texts[ind] + " " + extracted_texts[ind + 1] + for ind in range(len(cropped_lines_meging_indexing)) + if cropped_lines_meging_indexing[ind] >= 0] + extracted_confs_merged = [extracted_confs[ind] + if cropped_lines_meging_indexing[ind] == 0 + else 0.5 * (extracted_confs[ind] + extracted_confs[ind + 1]) + for ind in range(len(cropped_lines_meging_indexing)) + if cropped_lines_meging_indexing[ind] >= 0] return EynollahOcrResult( extracted_texts_merged=extracted_texts_merged, - extracted_conf_value_merged=None, + extracted_conf_value_merged=extracted_confs_merged, cropped_lines_region_indexer=cropped_lines_region_indexer, total_bb_coordinates=total_bb_coordinates, ) @@ -717,6 +635,7 @@ class Eynollah_ocr: has_textline = False for child_textregion in nn: + # FIXME: should remove Word level, if it already exists if child_textregion.tag.endswith("TextLine"): is_textline_text = False @@ -754,6 +673,7 @@ class Eynollah_ocr: indexer_textregion = indexer_textregion + 1 ET.register_namespace("",page_ns) + self.logger.info("output filename: '%s'", out_file_ocr) page_tree.write(out_file_ocr, xml_declaration=True, method='xml', encoding="utf-8", default_namespace=None) def run( diff --git a/src/eynollah/mb_ro_on_layout.py b/src/eynollah/mb_ro_on_layout.py index b0b5910..6c0477b 100644 --- a/src/eynollah/mb_ro_on_layout.py +++ b/src/eynollah/mb_ro_on_layout.py @@ -17,9 +17,7 @@ import cv2 import numpy as np import statistics -os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 -import tensorflow as tf - +from .eynollah import Eynollah from .model_zoo import EynollahModelZoo from .utils.resize import resize_image from .utils.contour import ( @@ -33,23 +31,27 @@ DPI_THRESHOLD = 298 KERNEL = np.ones((5, 5), np.uint8) -class machine_based_reading_order_on_layout: +class Reorder(Eynollah): def __init__( - self, - *, - model_zoo: EynollahModelZoo, - logger : Optional[logging.Logger] = None, + self, + *, + model_zoo: EynollahModelZoo, + logger : Optional[logging.Logger] = None, + device: str = '', ): self.logger = logger or logging.getLogger('eynollah.mbreorder') self.model_zoo = model_zoo - try: - for device in tf.config.list_physical_devices('GPU'): - tf.config.experimental.set_memory_growth(device, True) - except: - self.logger.warning("no GPU device available") - - self.model_zoo.load_models('reading_order') + self.model_zoo.load_model('reading_order') + self.setup_models(device=device) + + def setup_models(self, device=''): + loadable = ['reading_order'] + self.model_zoo.load_models(*loadable, device=device) + for model in loadable: + self.logger.debug("model %s has input shape %s", model, + self.model_zoo.get(model).input_shape) + def read_xml(self, xml_file): tree1 = ET.parse(xml_file, parser = ET.XMLParser(encoding='utf-8')) @@ -675,7 +677,7 @@ class machine_based_reading_order_on_layout: tot_counter += 1 batch.append(j) if tot_counter % inference_bs == 0 or tot_counter == len(ij_list): - y_pr = self.model_zoo.get('reading_order').predict(input_1 , verbose='0') + y_pr = self.model_zoo.get('reading_order').predict(input_1, verbose=0) for jb, j in enumerate(batch): if y_pr[jb][0]>=0.5: post_list.append(j) diff --git a/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 b/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 deleted file mode 100644 index c7dd87d..0000000 Binary files a/src/eynollah/model_zoo/.nfs00000002feddea7d00000031 and /dev/null differ diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 9611388..d5e69a2 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -35,7 +35,7 @@ class EynollahModelZoo: self._overrides = [] if model_overrides: self.override_models(*model_overrides) - self._loaded: Dict[str, Predictor] = {} + self._loaded: Dict[str, Union[Predictor, AnyModel]] = {} @property def model_overrides(self): @@ -70,6 +70,9 @@ class EynollahModelZoo: model_path = Path(self.model_basedir).joinpath(spec.filename) else: model_path = Path(spec.filename) + if model_path.suffix == '.h5' and Path(model_path.stem).exists(): + # prefer SavedModel over HDF5 format if it exists + model_path = Path(model_path.stem) return model_path def load_models( @@ -82,32 +85,50 @@ class EynollahModelZoo: """ ret = {} # cannot use self._loaded here, yet – first spawn all predictors for load_args in all_load_args: + load_kwargs = dict(device=device) if isinstance(load_args, str): - model_category = load_args - load_args = [model_category] + model_category, model_variant = load_args, "" + elif len(load_args) > 2: + # for calls to self.model_path + self.override_models(load_args) + # for calls to Predictor.load_model + model_category, model_variant, model_path = load_args + load_kwargs["model_variant"] = model_variant + load_kwargs["model_path_override"] = model_path else: - model_category = load_args[0] - load_kwargs = {} + model_category, model_variant = load_args + load_kwargs["model_variant"] = model_variant + if model_category.endswith('_resized'): - load_args[0] = model_category[:-8] + model_category = model_category[:-8] load_kwargs["resized"] = True elif model_category.endswith('_patched'): - load_args[0] = model_category[:-8] + model_category = model_category[:-8] load_kwargs["patched"] = True - spec = self.specs.get(model_category, load_args[1] if len(load_args) > 1 else '') - if spec.type in ['Keras'] and spec.category != 'ocr': - ret[model_category] = Predictor(self.logger, self) - ret[model_category].load_model(*load_args, **load_kwargs, device=device) + + if model_category == 'ocr': + model = self._load_ocr_model(variant=model_variant, device=device) + elif model_category == 'num_to_char': + model = self._load_num_to_char() + elif model_category == 'characters': + model = self._load_characters() + elif model_category == 'trocr_processor': + from transformers import TrOCRProcessor + model_path = self.model_path(model_category, model_variant) + model = TrOCRProcessor.from_pretrained(model_path) else: - ret[model_category] = self.load_model(*load_args, **load_kwargs, device=device) + model = Predictor(self.logger, self) + model.load_model(model_category, **load_kwargs) + + ret[model_category] = model self._loaded.update(ret) return self._loaded def load_model( - self, - model_category: str, - model_variant: str = '', - model_path_override: Optional[str] = None, + self, + model_category: str, + model_variant: str = '', + model_path_override: Optional[str] = None, patched: bool = False, resized: bool = False, device: str = '', @@ -121,6 +142,7 @@ class EynollahModelZoo: import tensorflow as tf from tensorflow.keras.models import load_model + from tensorflow.keras.models import Model as KerasModel from ..patch_encoder import ( PatchEncoder, @@ -132,7 +154,7 @@ class EynollahModelZoo: try: gpus = tf.config.list_physical_devices('GPU') if device: - if ',' in device: + if ':' in device: for spec in device.split(','): cat, dev = spec.split(':') if fnmatchcase(model_category, cat): @@ -147,7 +169,24 @@ class EynollahModelZoo: gpus = gpus[:1] # TF will always use first allowable tf.config.set_visible_devices(gpus, 'GPU') for device in gpus: - tf.config.experimental.set_memory_growth(device, True) + # tf.config.experimental.set_memory_growth(device, True) + # dynamic growth never frees memory (to avoid fragmentation), + # so the VRAM requirements end up much larger than feasible + # (for small GPUs); so try hard (calibrated) limits instead: + tf.config.set_logical_device_configuration( + device, + [tf.config.LogicalDeviceConfiguration(memory_limit={ + "binarization": 868, # due to bs 5 + "enhancement": 980, # due to bs 3 + "col_classifier": 210, + "page": 618, + "textline": 1680, # 954 for bs 1 + "region_1_2": 1580, + "region_fl_np": 1756, + "table": 1818, + "reading_order": 632, + "ocr": 850, + }[model_category])]) vendor_name = ( tf.config.experimental.get_device_details(device) .get('device_name', 'unknown')) @@ -166,65 +205,76 @@ class EynollahModelZoo: if model_path_override: self.override_models((model_category, model_variant, model_path_override)) model_path = self.model_path(model_category, model_variant) - if model_path.suffix == '.h5' and Path(model_path.stem).exists(): - # prefer SavedModel over HDF5 format if it exists - model_path = Path(model_path.stem) - if model_category == 'ocr': - model = self._load_ocr_model(variant=model_variant) - elif model_category == 'num_to_char': - model = self._load_num_to_char() - elif model_category == 'characters': - model = self._load_characters() - elif model_category == 'trocr_processor': - from transformers import TrOCRProcessor - model = TrOCRProcessor.from_pretrained(model_path) - else: - try: - # avoid wasting VRAM on non-transformer models - model = load_model(model_path, compile=False) - except Exception as e: - self.logger.error(e) - model = load_model( - model_path, compile=False, - custom_objects=dict(PatchEncoder=PatchEncoder, - Patches=Patches)) - model._name = model_category - if resized: - model = wrap_layout_model_resized(model) - model._name = model_category + '_resized' - elif patched: - model = wrap_layout_model_patched(model) - model._name = model_category + '_patched' - else: - model.jit_compile = True + try: + if model_path.is_dir() and not (model_path / "keras_metadata.pb").exists(): + # short-cut to avoid warning for exported models + raise ValueError() + model = load_model(model_path, compile=False) model.make_predict_function() + except (AttributeError, ValueError): + model = tf.saved_model.load(model_path) + model.predict_on_batch = model.serve + model.input_shape = tuple(model.signatures.get('serving_default').inputs[0].shape) + model._name = model_category + if resized: + model = wrap_layout_model_resized(model) + model._name = model_category + '_resized' + elif patched: + model = wrap_layout_model_patched(model) + model._name = model_category + '_patched' + else: + # increases required VRAM, does not always work + # (depending on CUDA/libcudnn/TF version): + #model.jit_compile = True + pass + + if model_category == 'ocr': + model = KerasModel( + model.get_layer(name="image").input, # type: ignore + model.get_layer(name="dense2").output, # type: ignore + ) + return model - def get(self, model_category: str) -> Predictor: + def get(self, model_category: str) -> Union[Predictor, AnyModel]: if model_category not in self._loaded: raise ValueError(f'Model "{model_category}" not previously loaded with "load_model(..)"') return self._loaded[model_category] - def _load_ocr_model(self, variant: str) -> AnyModel: + def _load_ocr_model(self, variant: str, device: str = "") -> AnyModel: """ Load OCR model """ - from tensorflow.keras.models import Model as KerasModel - from tensorflow.keras.models import load_model - - ocr_model_dir = self.model_path('ocr', variant) + model_dir = self.model_path('ocr', variant) if variant == 'tr': from transformers import VisionEncoderDecoderModel - ret = VisionEncoderDecoderModel.from_pretrained(ocr_model_dir) - assert isinstance(ret, VisionEncoderDecoderModel) - return ret - else: - ocr_model = load_model(ocr_model_dir, compile=False) - assert isinstance(ocr_model, KerasModel) - return KerasModel( - ocr_model.get_layer(name="image").input, # type: ignore - ocr_model.get_layer(name="dense2").output, # type: ignore - ) + import torch + model = VisionEncoderDecoderModel.from_pretrained(model_dir) + assert isinstance(model, VisionEncoderDecoderModel) + device0 = torch.device('cpu') + if not device and torch.cuda.is_available(): + device = 'GPU' # try + if device and ':' in device: + for spec in device.split(','): + cat, dev = spec.split(':') + if fnmatchcase('ocr', cat): + device = dev + break + if device and device.startswith('GPU'): + try: + device0 = torch.device('cuda', int(device[3:] or 0)) + name = torch.cuda.get_device_name(device0) + self.logger.info("using GPU %s (%s) for model ocr:tr", device0, name) + except: + self.logger.exception("cannot configure GPU device") + device0 = torch.device('cpu') + if device0.type == 'cuda': + model.to(device0) + else: + self.logger.warning("no GPU device available") + return model + + return self.load_model('ocr', model_variant=variant, device=device) def _load_characters(self) -> List[str]: """ @@ -237,6 +287,10 @@ class EynollahModelZoo: """ Load decoder for OCR """ + os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15 + from ocrd_utils import tf_disable_interactive_logs + tf_disable_interactive_logs() + from tensorflow.keras.layers import StringLookup characters = self._load_characters() @@ -277,5 +331,5 @@ class EynollahModelZoo: """ if hasattr(self, '_loaded') and getattr(self, '_loaded'): for needle in list(self._loaded.keys()): - self._loaded[needle].shutdown() - del self._loaded[needle] + if isinstance(self._loaded[needle], Predictor): + self._loaded[needle].shutdown() diff --git a/src/eynollah/ocrd_cli.py b/src/eynollah/ocrd_cli.py index acd8d4e..effecb2 100644 --- a/src/eynollah/ocrd_cli.py +++ b/src/eynollah/ocrd_cli.py @@ -1,10 +1,8 @@ -# NOTE: For predictable order of imports of torch/shapely/tensorflow -# this must be the first import of the CLI! -from .eynollah_imports import imported_libs -from .processor import EynollahProcessor from click import command from ocrd.decorators import ocrd_cli_options, ocrd_cli_wrap_processor +from .processor import EynollahProcessor + @command() @ocrd_cli_options def main(*args, **kwargs): diff --git a/src/eynollah/predictor.py b/src/eynollah/predictor.py index e1159e7..3c6890e 100644 --- a/src/eynollah/predictor.py +++ b/src/eynollah/predictor.py @@ -194,17 +194,18 @@ class Predictor(mp.context.SpawnProcess): def shutdown(self): # do not terminate from forked processor instances - if mp.parent_process() is None: + if not hasattr(self, 'model'): self.stopped.set() + self.join() self.taskq.close() self.taskq.cancel_join_thread() self.resultq.close() self.resultq.cancel_join_thread() self.logq.close() - self.terminate() + #self.terminate() else: del self.model def __del__(self): - #self.logger.debug(f"deinit of {self} in {mp.current_process().name}") + #self.logger.debug(f"deinit of {self.name} in {mp.current_process().name}") self.shutdown() diff --git a/src/eynollah/training/models.py b/src/eynollah/training/models.py index 3494249..f700d14 100644 --- a/src/eynollah/training/models.py +++ b/src/eynollah/training/models.py @@ -309,11 +309,10 @@ def transformer_block(img, # Skip connection 2. encoded_patches = Add()([x3, x2]) - encoded_patches = tf.reshape(encoded_patches, - [-1, - img.shape[1], - img.shape[2], - projection_dim // (patchsize_x * patchsize_y)]) + encoded_patches = Reshape(target_shape=(img.shape[1], + img.shape[2], + projection_dim // (patchsize_x * patchsize_y)), + name="reshape_patches")(encoded_patches) return encoded_patches def vit_resnet50_unet(num_patches, diff --git a/src/eynollah/training/reload-models-v0.8.mk b/src/eynollah/training/reload-models-v0.8.mk index b7a38dd..07be7cf 100644 --- a/src/eynollah/training/reload-models-v0.8.mk +++ b/src/eynollah/training/reload-models-v0.8.mk @@ -26,16 +26,17 @@ RELOADABLE_MODELS = \ all: $(RELOADABLE_MODELS) $(MODELS_DST)/%: $(MODELS_SRC)/% - mkdir -p $@ test -e $&1 | tee $(notdir $<).log - cp $ 0: args = ['-l', 'DEBUG'] + args caplog.set_level(logging.INFO) diff --git a/tests/cli_tests/test_layout.py b/tests/cli_tests/test_layout.py index 7cbe013..503aeac 100644 --- a/tests/cli_tests/test_layout.py +++ b/tests/cli_tests/test_layout.py @@ -6,11 +6,12 @@ from ocrd_models.constants import NAMESPACES as NS "options", [ [], # defaults - #["--allow_scaling", "--curved-line"], - ["--allow_scaling", "--curved-line", "--full-layout"], - ["--allow_scaling", "--curved-line", "--full-layout", "--reading_order_machine_based"], + #["--curved-line"], + ["--curved-line", "--full-layout"], + ["--curved-line", "--full-layout", "--reading_order_machine_based"], # -ep ... - # -eoi ... + # --input_binary + # --ignore_page_extraction # --skip_layout_and_reading_order ], ids=str) def test_run_eynollah_layout_filename( diff --git a/tests/cli_tests/test_ocr.py b/tests/cli_tests/test_ocr.py index 6bf3080..cf34e06 100644 --- a/tests/cli_tests/test_ocr.py +++ b/tests/cli_tests/test_ocr.py @@ -30,7 +30,7 @@ def test_run_eynollah_ocr_filename( '-o', str(outfile.parent), ] + options, [ - # FIXME: ocr has no logging! + 'output filename:' ] ) assert outfile.exists() @@ -57,7 +57,7 @@ def test_run_eynollah_ocr_directory( '-o', str(outdir), ], [ - # FIXME: ocr has no logging! + 'output filename:' ] ) assert len(list(outdir.iterdir())) == 2 diff --git a/tests/test_model_zoo.py b/tests/test_model_zoo.py index 9d37431..341bc21 100644 --- a/tests/test_model_zoo.py +++ b/tests/test_model_zoo.py @@ -6,10 +6,10 @@ def test_trocr1( model_zoo = EynollahModelZoo(model_dir) try: from transformers import TrOCRProcessor, VisionEncoderDecoderModel - model_zoo.load_models('trocr_processor') + model_zoo.load_models('trocr_processor', + ('ocr', 'tr')) proc = model_zoo.get('trocr_processor') assert isinstance(proc, TrOCRProcessor) - model_zoo.load_models(['ocr', 'tr']) model = model_zoo.get('ocr') assert isinstance(model, VisionEncoderDecoderModel) except ImportError: