wrap layout models for prediction (image resize or tiling) all in TF

(to avoid back and forth between CPU and GPU memory when looping
 over image patches)

- `patch_encoder`: define `Model` subclasses which take an existing
  (layout segmentation) model in the constructor, and define a new
  `call()` using the existing model in a GPU-only `tf.function`:
  * `wrap_layout_model_resized`: just `tf.image.resize()` from
    input image to model size, then predict, then resize back
  * `wrap_layout_model_patched`: ditto if smaller than model size;
    otherwise use `tf.image.extract_patches` for patching in a
    sliding-window approach, then predict patches one by one, then
    `tf.scatter_nd` to reconstruct to image size
- when compiling `tf.function` graph, make sure to use input signature
  with variable image size, but avoid retracing each new size sample
- in `EynollahModelZoo.load_model` for relevant model types,
  also wrap the loaded model
  * by `wrap_layout_model_resized` under model name + `_resized`
  * by `wrap_layout_model_patched` under model name + `_patched`
- introduce `do_prediction_new_concept_autosize`,
  replacing `do_prediction/_new_concept`,
  but using passed model's `predict` directly without
  resizing or tiling to model size
- instead of `do_prediction/_new_concept(True, ...)`,
  now call `do_prediction_new_concept_autosize`,
  but with `_patched` appended to model name
- instead of `do_prediction/_new_concept(False, ...)`,
  now call `do_prediction_new_concept_autosize`,
  but with `_resized` appended to model name
This commit is contained in:
Robert Sachunsky 2026-03-07 03:33:44 +01:00
parent f33fd57da8
commit 338c4a0edf
3 changed files with 179 additions and 25 deletions

View file

@ -914,8 +914,34 @@ class Eynollah:
gc.collect()
return prediction, confidence
def do_prediction_new_concept_autosize(
self, img, model,
thresholding_for_heading=False,
thresholding_for_artificial_class=False,
threshold_art_class=0.1,
artificial_class=4,
):
self.logger.debug("enter do_prediction_new_concept (%s)", model.name)
img = img / 255.0
prediction = model.predict(img[np.newaxis]).numpy()[0]
confidence = prediction[:, :, 1]
segmentation = np.argmax(prediction, axis=2).astype(np.uint8)
if thresholding_for_artificial_class:
seg_mask_label(segmentation,
prediction[:, :, artificial_class] >= threshold_art_class,
label=artificial_class,
only=True,
skeletonize=True,
dilate=3)
if thresholding_for_heading:
seg_mask_label(segmentation,
prediction[:, :, 2] >= 0.2,
label=2)
gc.collect()
return segmentation, confidence
def extract_page(self):
self.logger.debug("enter extract_page")
@ -990,7 +1016,9 @@ class Eynollah:
self.logger.debug("enter extract_text_regions")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np")
#model_name = "region_fl" if patches else "region_fl_np"
model_name = "region_fl_patched" if patches else "region_fl_np_resized"
model_region = self.model_zoo.get(model_name)
thresholding_for_heading = True
img = otsu_copy_binary(img).astype(np.uint8)
@ -1010,10 +1038,9 @@ class Eynollah:
else:
img = resize_image(img, int(img_height_h * 2500 / float(img_width_h)), 2500).astype(np.uint8)
prediction_regions = self.do_prediction(patches, img, model_region,
marginal_of_patch_percent=0.1,
n_batch_inference=3,
thresholding_for_heading=thresholding_for_heading)
prediction_regions, _ = self.do_prediction_new_concept_autosize(
img, model_region,
thresholding_for_heading=thresholding_for_heading)
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
self.logger.debug("exit extract_text_regions")
return prediction_regions
@ -1162,10 +1189,11 @@ class Eynollah:
def textline_contours(self, img, use_patches, num_col_classifier=None):
self.logger.debug('enter textline_contours')
prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"),
marginal_of_patch_percent=0.15,
n_batch_inference=3,
threshold_art_class=self.threshold_art_class_textline)
prediction_textline, _ = self.do_prediction_new_concept_autosize(
img, self.model_zoo.get("textline_patched" if use_patches else "textline_resized"),
artificial_class=2,
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_textline)
#prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline"))
@ -1242,17 +1270,19 @@ class Eynollah:
if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
self.logger.debug("resized to %dx%d for %d cols",
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1,
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
prediction_regions_org, confidence_matrix = \
self.do_prediction_new_concept_autosize(
img_resized, self.model_zoo.get("region_1_2_patched"),
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
else:
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1,
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
prediction_regions_page, confidence_matrix_page = \
self.do_prediction_new_concept_autosize(
self.image_page_org_size, self.model_zoo.get("region_1_2_resized"),
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
ys = slice(*self.page_coord[0:2])
xs = slice(*self.page_coord[2:4])
prediction_regions_org[ys, xs] = prediction_regions_page
@ -1263,10 +1293,11 @@ class Eynollah:
img_resized = resize_image(img_bin, int(new_h * img_bin.shape[0] /img_bin.shape[1]), new_h)
self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2,
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
prediction_regions_org, confidence_matrix = \
self.do_prediction_new_concept_autosize(
img_resized, self.model_zoo.get("region_1_2_patched"),
thresholding_for_artificial_class=True,
threshold_art_class=self.threshold_art_class_layout)
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"),
###n_batch_inference=3,
###thresholding_for_some_classes=True)
@ -1664,8 +1695,7 @@ class Eynollah:
img_org = np.copy(img)
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]
patches = False
prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table"))
prediction_table, _ = self.do_prediction_new_concept_autosize(img, self.model_zoo.get("table_resized"))
prediction_table = prediction_table.astype(np.int16)
return prediction_table

View file

@ -14,7 +14,12 @@ from tensorflow.keras.models import Model as KerasModel
from tensorflow.keras.models import load_model
from tabulate import tabulate
from ..patch_encoder import PatchEncoder, Patches
from ..patch_encoder import (
PatchEncoder,
Patches,
wrap_layout_model_patched,
wrap_layout_model_resized,
)
from .specs import EynollahModelSpecSet
from .default_specs import DEFAULT_MODEL_SPECS
from .types import AnyModel, T
@ -125,7 +130,12 @@ class EynollahModelZoo:
model = load_model(
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
)
model._name = model_category
self._loaded[model_category] = model
if model_category in ['region_1_2', 'table', 'region_fl_np']:
self._loaded[model_category + '_resized'] = wrap_layout_model_resized(model)
if model_category in ['region_1_2', 'textline']:
self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model)
return model # type: ignore
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:

View file

@ -1,7 +1,7 @@
import os
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import layers, models
class PatchEncoder(layers.Layer):
@ -44,3 +44,117 @@ class Patches(layers.Layer):
return dict(patch_size_x=self.patch_size_x,
patch_size_y=self.patch_size_y,
**super().get_config())
class wrap_layout_model_resized(models.Model):
"""
replacement for layout model using resizing to model width/height and back
(accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])
"""
def __init__(self, model):
super().__init__(name=model.name + '_resized')
self.model = model
self.height = model.layers[-1].output_shape[1]
self.width = model.layers[-1].output_shape[2]
def call(self, img, training=False):
height = tf.shape(img)[1]
width = tf.shape(img)[2]
img_resized = tf.image.resize(img,
(self.height, self.width),
antialias=True)
pred_resized = self.model(img_resized)
pred = tf.image.resize(pred_resized,
(height, width))
return pred
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate is not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def predict(self, x):
return self(x)
class wrap_layout_model_patched(models.Model):
"""
replacement for layout model using sliding window for patches
(accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])
"""
def __init__(self, model):
super().__init__(name=model.name + '_patched')
self.model = model
self.height = model.layers[-1].output_shape[1]
self.width = model.layers[-1].output_shape[2]
self.classes = model.layers[-1].output_shape[3]
# equivalent of marginal_of_patch_percent=0.1 ...
self.stride_x = int(self.width * (1 - 0.1))
self.stride_y = int(self.height * (1 - 0.1))
offset_height = (self.height - self.stride_y) // 2
offset_width = (self.width - self.stride_x) // 2
window = tf.image.pad_to_bounding_box(
tf.ones((self.stride_y, self.stride_x, 1), dtype=tf.int32),
offset_height, offset_width,
self.height, self.width)
self.window = tf.expand_dims(window, axis=0)
def call(self, img, training=False):
height = tf.shape(img)[1]
width = tf.shape(img)[2]
if (height < self.height or
width < self.width):
img_resized = tf.image.resize(img,
(self.height, self.width),
antialias=True)
pred_resized = self.model(img_resized)
pred = tf.image.resize(pred_resized,
(height, width))
return pred
img_patches = tf.image.extract_patches(
images=img,
sizes=[1, self.height, self.width, 1],
strides=[1, self.stride_y, self.stride_x, 1],
rates=[1, 1, 1, 1],
padding='SAME')
img_patches = tf.squeeze(img_patches)
new_shape = (-1, self.height, self.width, 3)
img_patches = tf.reshape(img_patches, shape=new_shape)
# may be too large:
#pred_patches = self.model(img_patches)
# so rebatch to fit in memory:
img_patches = tf.expand_dims(img_patches, 1)
pred_patches = tf.map_fn(self.model, img_patches,
parallel_iterations=1,
infer_shape=False)
pred_patches = tf.squeeze(pred_patches, 1)
# calculate corresponding indexes for reconstruction
x = tf.range(width)
y = tf.range(height)
x, y = tf.meshgrid(x, y)
indices = tf.stack([y, x], axis=-1)
indices_patches = tf.image.extract_patches(
images=tf.expand_dims(indices, axis=0),
sizes=[1, self.height, self.width, 1],
strides=[1, self.stride_y, self.stride_x, 1],
rates=[1, 1, 1, 1],
padding='SAME')
indices_patches = tf.squeeze(indices_patches)
indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,))
# use margins for sliding window approach
indices_patches = indices_patches * self.window
pred = tf.scatter_nd(
indices_patches,
pred_patches,
(height, width, self.classes))
pred = tf.expand_dims(pred, axis=0)
return pred
@tf.function(reduce_retracing=True,
#jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
input_signature=[tf.TensorSpec([1, None, None, 3],
dtype=tf.float32)])
def predict(self, x):
return self(x)