From 338c4a0edff998929626df637ce8f77f59af1f2d Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Sat, 7 Mar 2026 03:33:44 +0100 Subject: [PATCH] wrap layout models for prediction (image resize or tiling) all in TF (to avoid back and forth between CPU and GPU memory when looping over image patches) - `patch_encoder`: define `Model` subclasses which take an existing (layout segmentation) model in the constructor, and define a new `call()` using the existing model in a GPU-only `tf.function`: * `wrap_layout_model_resized`: just `tf.image.resize()` from input image to model size, then predict, then resize back * `wrap_layout_model_patched`: ditto if smaller than model size; otherwise use `tf.image.extract_patches` for patching in a sliding-window approach, then predict patches one by one, then `tf.scatter_nd` to reconstruct to image size - when compiling `tf.function` graph, make sure to use input signature with variable image size, but avoid retracing each new size sample - in `EynollahModelZoo.load_model` for relevant model types, also wrap the loaded model * by `wrap_layout_model_resized` under model name + `_resized` * by `wrap_layout_model_patched` under model name + `_patched` - introduce `do_prediction_new_concept_autosize`, replacing `do_prediction/_new_concept`, but using passed model's `predict` directly without resizing or tiling to model size - instead of `do_prediction/_new_concept(True, ...)`, now call `do_prediction_new_concept_autosize`, but with `_patched` appended to model name - instead of `do_prediction/_new_concept(False, ...)`, now call `do_prediction_new_concept_autosize`, but with `_resized` appended to model name --- src/eynollah/eynollah.py | 76 ++++++++++++------ src/eynollah/model_zoo/model_zoo.py | 12 ++- src/eynollah/patch_encoder.py | 116 +++++++++++++++++++++++++++- 3 files changed, 179 insertions(+), 25 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index dcdc642..d536cc8 100644 --- 
def do_prediction_new_concept_autosize(
        self, img, model,
        thresholding_for_heading=False,
        thresholding_for_artificial_class=False,
        threshold_art_class=0.1,
        artificial_class=4,
):
    """
    Segment ``img`` with a size-agnostic wrapped layout ``model``.

    Unlike ``do_prediction``/``do_prediction_new_concept``, no resizing or
    tiling happens here: the passed model (a ``wrap_layout_model_resized``
    or ``wrap_layout_model_patched`` instance) already accepts arbitrary
    image sizes and returns a same-size per-pixel class probability map.

    :param img: HxWx3 image in 0..255 value range
    :param model: wrapped model whose ``predict`` takes a [1,H,W,3]
        float32 tensor and returns a [1,H,W,C] probability tensor
    :param thresholding_for_heading: additionally mark pixels whose
        class-2 (heading) probability is >= 0.2 with label 2
    :param thresholding_for_artificial_class: additionally re-mask the
        artificial (separator) class from its raw probability channel
    :param threshold_art_class: probability threshold for the artificial class
    :param artificial_class: label index of the artificial class
        (4 for region models, 2 for the textline model)
    :return: tuple of (uint8 label map, float confidence map), both HxW
    """
    self.logger.debug("enter do_prediction_new_concept_autosize (%s)", model.name)
    # scale to unit range; cast explicitly to float32 so the wrapped model's
    # tf.function input signature (tf.float32) is satisfied -- plain
    # `img / 255.0` would produce float64 and fail the signature check
    img = img.astype(np.float32) / 255.0

    # the wrapped model handles resizing/tiling internally on GPU;
    # predict returns a TF tensor, so pull it back to numpy here
    prediction = model.predict(img[np.newaxis]).numpy()[0]
    # NOTE(review): confidence is hard-coded to channel 1 (main text class)
    confidence = prediction[:, :, 1]
    segmentation = np.argmax(prediction, axis=2).astype(np.uint8)

    if thresholding_for_artificial_class:
        # re-derive the artificial class directly from its probability:
        # thin skeleton, slightly dilated, overriding the argmax result
        seg_mask_label(segmentation,
                       prediction[:, :, artificial_class] >= threshold_art_class,
                       label=artificial_class,
                       only=True,
                       skeletonize=True,
                       dilate=3)
    if thresholding_for_heading:
        # headings (class 2) win wherever their probability reaches 0.2
        seg_mask_label(segmentation,
                       prediction[:, :, 2] >= 0.2,
                       label=2)
    gc.collect()
    return segmentation, confidence
resize_image(prediction_regions, img_height_h, img_width_h) self.logger.debug("exit extract_text_regions") return prediction_regions @@ -1162,10 +1189,11 @@ class Eynollah: def textline_contours(self, img, use_patches, num_col_classifier=None): self.logger.debug('enter textline_contours') - prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"), - marginal_of_patch_percent=0.15, - n_batch_inference=3, - threshold_art_class=self.threshold_art_class_textline) + prediction_textline, _ = self.do_prediction_new_concept_autosize( + img, self.model_zoo.get("textline_patched" if use_patches else "textline_resized"), + artificial_class=2, + thresholding_for_artificial_class=True, + threshold_art_class=self.threshold_art_class_textline) #prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline")) @@ -1242,17 +1270,19 @@ class Eynollah: if self.image_org.shape[0]/self.image_org.shape[1] > 2.5: self.logger.debug("resized to %dx%d for %d cols", img_resized.shape[1], img_resized.shape[0], num_col_classifier) - prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1, - thresholding_for_artificial_class=True, - threshold_art_class=self.threshold_art_class_layout) + prediction_regions_org, confidence_matrix = \ + self.do_prediction_new_concept_autosize( + img_resized, self.model_zoo.get("region_1_2_patched"), + thresholding_for_artificial_class=True, + threshold_art_class=self.threshold_art_class_layout) else: prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1])) confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1])) - prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept( - False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1, - thresholding_for_artificial_class=True, - 
threshold_art_class=self.threshold_art_class_layout) + prediction_regions_page, confidence_matrix_page = \ + self.do_prediction_new_concept_autosize( + self.image_page_org_size, self.model_zoo.get("region_1_2_resized"), + thresholding_for_artificial_class=True, + threshold_art_class=self.threshold_art_class_layout) ys = slice(*self.page_coord[0:2]) xs = slice(*self.page_coord[2:4]) prediction_regions_org[ys, xs] = prediction_regions_page @@ -1263,10 +1293,11 @@ class Eynollah: img_resized = resize_image(img_bin, int(new_h * img_bin.shape[0] /img_bin.shape[1]), new_h) self.logger.debug("resized to %dx%d (new_h=%d) for %d cols", img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier) - prediction_regions_org, confidence_matrix = self.do_prediction_new_concept( - True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2, - thresholding_for_artificial_class=True, - threshold_art_class=self.threshold_art_class_layout) + prediction_regions_org, confidence_matrix = \ + self.do_prediction_new_concept_autosize( + img_resized, self.model_zoo.get("region_1_2_patched"), + thresholding_for_artificial_class=True, + threshold_art_class=self.threshold_art_class_layout) ###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"), ###n_batch_inference=3, ###thresholding_for_some_classes=True) @@ -1664,8 +1695,7 @@ class Eynollah: img_org = np.copy(img) img_height_h = img_org.shape[0] img_width_h = img_org.shape[1] - patches = False - prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table")) + prediction_table, _ = self.do_prediction_new_concept_autosize(img, self.model_zoo.get("table_resized")) prediction_table = prediction_table.astype(np.int16) return prediction_table diff --git a/src/eynollah/model_zoo/model_zoo.py b/src/eynollah/model_zoo/model_zoo.py index 2147b0e..ca5de05 100644 --- a/src/eynollah/model_zoo/model_zoo.py +++ b/src/eynollah/model_zoo/model_zoo.py @@ -14,7 
class wrap_layout_model_resized(models.Model):
    """
    Drop-in substitute for a layout model that handles arbitrary input
    sizes by resizing to the wrapped model's native resolution and back.

    Accepts [B, H, W, 3] input of any height/width and yields a
    [B, H, W, C] segmentation at the original size.
    """
    def __init__(self, model):
        super().__init__(name=model.name + '_resized')
        self.model = model
        # native input resolution, read off the wrapped model's final layer
        out_shape = model.layers[-1].output_shape
        self.height = out_shape[1]
        self.width = out_shape[2]

    def call(self, img, training=False):
        orig_h = tf.shape(img)[1]
        orig_w = tf.shape(img)[2]
        # scale to native size, predict, then scale the prediction back
        at_native = tf.image.resize(img,
                                    (self.height, self.width),
                                    antialias=True)
        prediction = self.model(at_native)
        return tf.image.resize(prediction, (orig_h, orig_w))

    @tf.function(reduce_retracing=True,
                 #jit_compile=True, (ScaleAndTranslate is not supported by XLA)
                 input_signature=[tf.TensorSpec([1, None, None, 3],
                                                dtype=tf.float32)])
    def predict(self, x):
        return self(x)
class wrap_layout_model_patched(models.Model):
    """
    replacement for layout model using sliding window for patches

    (accepts arbitrary width/height input [B, H, W, 3], returns same size
    segmentation [B, H, W, C])

    NOTE(review): `predict`'s input signature pins the batch dimension to 1,
    and `call` re-adds a batch axis of size 1 at the end -- so despite the
    [B, ...] docstring only B == 1 is actually supported; confirm intent.
    """
    def __init__(self, model):
        # expose a derived name so the model zoo can register the wrapper
        super().__init__(name=model.name + '_patched')
        self.model = model
        # native patch size and class count, read off the wrapped model's
        # final layer (relies on legacy-Keras `output_shape`)
        self.height = model.layers[-1].output_shape[1]
        self.width = model.layers[-1].output_shape[2]
        self.classes = model.layers[-1].output_shape[3]
        # equivalent of marginal_of_patch_percent=0.1 ...
        # ... i.e. neighbouring patches overlap by 10% of the patch size
        self.stride_x = int(self.width * (1 - 0.1))
        self.stride_y = int(self.height * (1 - 0.1))
        offset_height = (self.height - self.stride_y) // 2
        offset_width = (self.width - self.stride_x) // 2
        # binary mask [1, height, width, 1]: 1 in the central
        # stride_y x stride_x core of a patch, 0 in the marginal border
        window = tf.image.pad_to_bounding_box(
            tf.ones((self.stride_y, self.stride_x, 1), dtype=tf.int32),
            offset_height, offset_width,
            self.height, self.width)
        self.window = tf.expand_dims(window, axis=0)

    def call(self, img, training=False):
        height = tf.shape(img)[1]
        width = tf.shape(img)[2]
        # image smaller than one patch on either axis:
        # fall back to plain resize -> predict -> resize back
        if (height < self.height or
            width < self.width):
            img_resized = tf.image.resize(img,
                                          (self.height, self.width),
                                          antialias=True)
            pred_resized = self.model(img_resized)
            pred = tf.image.resize(pred_resized,
                                   (height, width))
            return pred

        # cut the image into overlapping patches; SAME padding extends the
        # borders so every pixel is covered by at least one patch
        img_patches = tf.image.extract_patches(
            images=img,
            sizes=[1, self.height, self.width, 1],
            strides=[1, self.stride_y, self.stride_x, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        img_patches = tf.squeeze(img_patches)
        new_shape = (-1, self.height, self.width, 3)
        img_patches = tf.reshape(img_patches, shape=new_shape)
        # may be too large:
        #pred_patches = self.model(img_patches)
        # so rebatch to fit in memory:
        # (map_fn with parallel_iterations=1 predicts one patch at a time,
        # each as its own batch of 1)
        img_patches = tf.expand_dims(img_patches, 1)
        pred_patches = tf.map_fn(self.model, img_patches,
                                 parallel_iterations=1,
                                 infer_shape=False)
        pred_patches = tf.squeeze(pred_patches, 1)
        # calculate corresponding indexes for reconstruction
        # (per-pixel (y, x) coordinates, patched exactly like the image)
        x = tf.range(width)
        y = tf.range(height)
        x, y = tf.meshgrid(x, y)
        indices = tf.stack([y, x], axis=-1)
        indices_patches = tf.image.extract_patches(
            images=tf.expand_dims(indices, axis=0),
            sizes=[1, self.height, self.width, 1],
            strides=[1, self.stride_y, self.stride_x, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        indices_patches = tf.squeeze(indices_patches)
        indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,))

        # use margins for sliding window approach
        # NOTE(review): multiplying the coordinates by the window zeroes out
        # every marginal position, so all margin predictions are scattered
        # onto pixel (0, 0); scatter_nd *sums* duplicate indices rather than
        # overwriting, so (0, 0) accumulates those contributions -- confirm
        # this is the intended margin handling
        indices_patches = indices_patches * self.window

        pred = tf.scatter_nd(
            indices_patches,
            pred_patches,
            (height, width, self.classes))
        pred = tf.expand_dims(pred, axis=0)
        return pred

    @tf.function(reduce_retracing=True,
                 #jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
                 input_signature=[tf.TensorSpec([1, None, None, 3],
                                                dtype=tf.float32)])
    def predict(self, x):
        return self(x)