mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-24 16:12:03 +01:00
wrap layout models for prediction (image resize or tiling) all in TF
(to avoid back and forth between CPU and GPU memory when looping
over image patches)
- `patch_encoder`: define `Model` subclasses which take an existing
(layout segmentation) model in the constructor, and define a new
`call()` using the existing model in a GPU-only `tf.function`:
* `wrap_layout_model_resized`: just `tf.image.resize()` from
input image to model size, then predict, then resize back
* `wrap_layout_model_patched`: ditto if smaller than model size;
otherwise use `tf.image.extract_patches` for patching in a
sliding-window approach, then predict patches one by one, then
`tf.scatter_nd` to reconstruct to image size
- when compiling the `tf.function` graph, make sure to use an input signature
with variable image size, and avoid retracing for each newly seen image size
- in `EynollahModelZoo.load_model` for relevant model types,
also wrap the loaded model
* by `wrap_layout_model_resized` under model name + `_resized`
* by `wrap_layout_model_patched` under model name + `_patched`
- introduce `do_prediction_new_concept_autosize`,
replacing `do_prediction/_new_concept`,
but using passed model's `predict` directly without
resizing or tiling to model size
- instead of `do_prediction/_new_concept(True, ...)`,
now call `do_prediction_new_concept_autosize`,
but with `_patched` appended to model name
- instead of `do_prediction/_new_concept(False, ...)`,
now call `do_prediction_new_concept_autosize`,
but with `_resized` appended to model name
This commit is contained in:
parent
f33fd57da8
commit
338c4a0edf
3 changed files with 179 additions and 25 deletions
|
|
@ -914,8 +914,34 @@ class Eynollah:
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return prediction, confidence
|
return prediction, confidence
|
||||||
|
|
||||||
|
def do_prediction_new_concept_autosize(
        self, img, model,
        thresholding_for_heading=False,
        thresholding_for_artificial_class=False,
        threshold_art_class=0.1,
        artificial_class=4,
):
    """
    Run `model` on a whole image and derive a label map from its output.

    Unlike the resizing/tiling prediction helpers, this method hands the
    image to ``model.predict`` as is (only scaled by 1/255 and given a
    leading batch axis), so the passed model must cope with the image
    size on its own.

    Args:
        img: image array indexed as [row, column, channel]; it is divided
            by 255.0 before prediction.
        model: object with a ``name`` attribute and a ``predict`` method
            whose result offers ``.numpy()`` returning an array with a
            leading batch axis, i.e. [1, rows, columns, classes].
        thresholding_for_heading: if true, call ``seg_mask_label`` with
            the mask ``prediction[:, :, 2] >= 0.2`` and ``label=2``.
        thresholding_for_artificial_class: if true, call
            ``seg_mask_label`` with the mask
            ``prediction[:, :, artificial_class] >= threshold_art_class``
            and ``label=artificial_class`` (with ``only=True``,
            ``skeletonize=True``, ``dilate=3``).
        threshold_art_class: score threshold for the artificial class.
        artificial_class: channel index / label of the artificial class.

    Returns:
        Tuple ``(segmentation, confidence)``: the per-pixel argmax over
        the class axis as ``uint8``, and channel 1 of the raw prediction.
    """
    # name the function actually being entered, so debug traces are not
    # confused with do_prediction_new_concept
    self.logger.debug("enter do_prediction_new_concept_autosize (%s)", model.name)
    img = img / 255.0

    # add the batch axis for the model, strip it again from the result
    prediction = model.predict(img[np.newaxis]).numpy()[0]
    confidence = prediction[:, :, 1]
    segmentation = np.argmax(prediction, axis=2).astype(np.uint8)

    # NOTE(review): the return value of seg_mask_label is not used, so it
    # presumably modifies `segmentation` in place - confirm in its definition
    if thresholding_for_artificial_class:
        seg_mask_label(segmentation,
                       prediction[:, :, artificial_class] >= threshold_art_class,
                       label=artificial_class,
                       only=True,
                       skeletonize=True,
                       dilate=3)
    if thresholding_for_heading:
        seg_mask_label(segmentation,
                       prediction[:, :, 2] >= 0.2,
                       label=2)
    gc.collect()
    return segmentation, confidence
|
||||||
|
|
||||||
def extract_page(self):
|
def extract_page(self):
|
||||||
self.logger.debug("enter extract_page")
|
self.logger.debug("enter extract_page")
|
||||||
|
|
@ -990,7 +1016,9 @@ class Eynollah:
|
||||||
self.logger.debug("enter extract_text_regions")
|
self.logger.debug("enter extract_text_regions")
|
||||||
img_height_h = img.shape[0]
|
img_height_h = img.shape[0]
|
||||||
img_width_h = img.shape[1]
|
img_width_h = img.shape[1]
|
||||||
model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np")
|
#model_name = "region_fl" if patches else "region_fl_np"
|
||||||
|
model_name = "region_fl_patched" if patches else "region_fl_np_resized"
|
||||||
|
model_region = self.model_zoo.get(model_name)
|
||||||
|
|
||||||
thresholding_for_heading = True
|
thresholding_for_heading = True
|
||||||
img = otsu_copy_binary(img).astype(np.uint8)
|
img = otsu_copy_binary(img).astype(np.uint8)
|
||||||
|
|
@ -1010,9 +1038,8 @@ class Eynollah:
|
||||||
else:
|
else:
|
||||||
img = resize_image(img, int(img_height_h * 2500 / float(img_width_h)), 2500).astype(np.uint8)
|
img = resize_image(img, int(img_height_h * 2500 / float(img_width_h)), 2500).astype(np.uint8)
|
||||||
|
|
||||||
prediction_regions = self.do_prediction(patches, img, model_region,
|
prediction_regions, _ = self.do_prediction_new_concept_autosize(
|
||||||
marginal_of_patch_percent=0.1,
|
img, model_region,
|
||||||
n_batch_inference=3,
|
|
||||||
thresholding_for_heading=thresholding_for_heading)
|
thresholding_for_heading=thresholding_for_heading)
|
||||||
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
|
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
|
||||||
self.logger.debug("exit extract_text_regions")
|
self.logger.debug("exit extract_text_regions")
|
||||||
|
|
@ -1162,9 +1189,10 @@ class Eynollah:
|
||||||
def textline_contours(self, img, use_patches, num_col_classifier=None):
|
def textline_contours(self, img, use_patches, num_col_classifier=None):
|
||||||
self.logger.debug('enter textline_contours')
|
self.logger.debug('enter textline_contours')
|
||||||
|
|
||||||
prediction_textline = self.do_prediction(use_patches, img, self.model_zoo.get("textline"),
|
prediction_textline, _ = self.do_prediction_new_concept_autosize(
|
||||||
marginal_of_patch_percent=0.15,
|
img, self.model_zoo.get("textline_patched" if use_patches else "textline_resized"),
|
||||||
n_batch_inference=3,
|
artificial_class=2,
|
||||||
|
thresholding_for_artificial_class=True,
|
||||||
threshold_art_class=self.threshold_art_class_textline)
|
threshold_art_class=self.threshold_art_class_textline)
|
||||||
|
|
||||||
#prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline"))
|
#prediction_textline_longshot = self.do_prediction(False, img, self.model_zoo.get("textline"))
|
||||||
|
|
@ -1242,15 +1270,17 @@ class Eynollah:
|
||||||
if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
|
if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
|
||||||
self.logger.debug("resized to %dx%d for %d cols",
|
self.logger.debug("resized to %dx%d for %d cols",
|
||||||
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
|
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
|
||||||
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
prediction_regions_org, confidence_matrix = \
|
||||||
True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=1,
|
self.do_prediction_new_concept_autosize(
|
||||||
|
img_resized, self.model_zoo.get("region_1_2_patched"),
|
||||||
thresholding_for_artificial_class=True,
|
thresholding_for_artificial_class=True,
|
||||||
threshold_art_class=self.threshold_art_class_layout)
|
threshold_art_class=self.threshold_art_class_layout)
|
||||||
else:
|
else:
|
||||||
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
||||||
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
confidence_matrix = np.zeros((self.image_org.shape[0], self.image_org.shape[1]))
|
||||||
prediction_regions_page, confidence_matrix_page = self.do_prediction_new_concept(
|
prediction_regions_page, confidence_matrix_page = \
|
||||||
False, self.image_page_org_size, self.model_zoo.get("region_1_2"), n_batch_inference=1,
|
self.do_prediction_new_concept_autosize(
|
||||||
|
self.image_page_org_size, self.model_zoo.get("region_1_2_resized"),
|
||||||
thresholding_for_artificial_class=True,
|
thresholding_for_artificial_class=True,
|
||||||
threshold_art_class=self.threshold_art_class_layout)
|
threshold_art_class=self.threshold_art_class_layout)
|
||||||
ys = slice(*self.page_coord[0:2])
|
ys = slice(*self.page_coord[0:2])
|
||||||
|
|
@ -1263,8 +1293,9 @@ class Eynollah:
|
||||||
img_resized = resize_image(img_bin, int(new_h * img_bin.shape[0] /img_bin.shape[1]), new_h)
|
img_resized = resize_image(img_bin, int(new_h * img_bin.shape[0] /img_bin.shape[1]), new_h)
|
||||||
self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
|
self.logger.debug("resized to %dx%d (new_h=%d) for %d cols",
|
||||||
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
|
img_resized.shape[1], img_resized.shape[0], new_h, num_col_classifier)
|
||||||
prediction_regions_org, confidence_matrix = self.do_prediction_new_concept(
|
prediction_regions_org, confidence_matrix = \
|
||||||
True, img_resized, self.model_zoo.get("region_1_2"), n_batch_inference=2,
|
self.do_prediction_new_concept_autosize(
|
||||||
|
img_resized, self.model_zoo.get("region_1_2_patched"),
|
||||||
thresholding_for_artificial_class=True,
|
thresholding_for_artificial_class=True,
|
||||||
threshold_art_class=self.threshold_art_class_layout)
|
threshold_art_class=self.threshold_art_class_layout)
|
||||||
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"),
|
###prediction_regions_org = self.do_prediction(True, img_bin, self.model_zoo.get_model("region"),
|
||||||
|
|
@ -1664,8 +1695,7 @@ class Eynollah:
|
||||||
img_org = np.copy(img)
|
img_org = np.copy(img)
|
||||||
img_height_h = img_org.shape[0]
|
img_height_h = img_org.shape[0]
|
||||||
img_width_h = img_org.shape[1]
|
img_width_h = img_org.shape[1]
|
||||||
patches = False
|
prediction_table, _ = self.do_prediction_new_concept_autosize(img, self.model_zoo.get("table_resized"))
|
||||||
prediction_table, _ = self.do_prediction_new_concept(patches, img, self.model_zoo.get("table"))
|
|
||||||
prediction_table = prediction_table.astype(np.int16)
|
prediction_table = prediction_table.astype(np.int16)
|
||||||
return prediction_table
|
return prediction_table
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,12 @@ from tensorflow.keras.models import Model as KerasModel
|
||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
|
||||||
from ..patch_encoder import PatchEncoder, Patches
|
from ..patch_encoder import (
|
||||||
|
PatchEncoder,
|
||||||
|
Patches,
|
||||||
|
wrap_layout_model_patched,
|
||||||
|
wrap_layout_model_resized,
|
||||||
|
)
|
||||||
from .specs import EynollahModelSpecSet
|
from .specs import EynollahModelSpecSet
|
||||||
from .default_specs import DEFAULT_MODEL_SPECS
|
from .default_specs import DEFAULT_MODEL_SPECS
|
||||||
from .types import AnyModel, T
|
from .types import AnyModel, T
|
||||||
|
|
@ -125,7 +130,12 @@ class EynollahModelZoo:
|
||||||
model = load_model(
|
model = load_model(
|
||||||
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
model_path, compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches}
|
||||||
)
|
)
|
||||||
|
model._name = model_category
|
||||||
self._loaded[model_category] = model
|
self._loaded[model_category] = model
|
||||||
|
if model_category in ['region_1_2', 'table', 'region_fl_np']:
|
||||||
|
self._loaded[model_category + '_resized'] = wrap_layout_model_resized(model)
|
||||||
|
if model_category in ['region_1_2', 'textline']:
|
||||||
|
self._loaded[model_category + '_patched'] = wrap_layout_model_patched(model)
|
||||||
return model # type: ignore
|
return model # type: ignore
|
||||||
|
|
||||||
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
|
def get(self, model_category: str, model_type: Optional[Type[T]] = None) -> T:
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
os.environ['TF_USE_LEGACY_KERAS'] = '1' # avoid Keras 3 after TF 2.15
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers, models
|
||||||
|
|
||||||
class PatchEncoder(layers.Layer):
|
class PatchEncoder(layers.Layer):
|
||||||
|
|
||||||
|
|
@ -44,3 +44,117 @@ class Patches(layers.Layer):
|
||||||
return dict(patch_size_x=self.patch_size_x,
|
return dict(patch_size_x=self.patch_size_x,
|
||||||
patch_size_y=self.patch_size_y,
|
patch_size_y=self.patch_size_y,
|
||||||
**super().get_config())
|
**super().get_config())
|
||||||
|
|
||||||
|
class wrap_layout_model_resized(models.Model):
    """
    Drop-in stand-in for a layout model that rescales its input.

    The image is resized to the wrapped model's width/height, passed
    through the model, and the result is resized back to the size of the
    original input.

    (accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])
    """
    def __init__(self, model):
        super().__init__(name=model.name + '_resized')
        self.model = model
        # model size is read off the output shape of the wrapped model's last layer
        last_shape = model.layers[-1].output_shape
        self.height = last_shape[1]
        self.width = last_shape[2]

    def call(self, img, training=False):
        # `training` is not used by this wrapper
        # remember the (dynamic) input size so the result can be scaled back
        orig_size = tf.shape(img)[1:3]
        shrunk = tf.image.resize(img,
                                 (self.height, self.width),
                                 antialias=True)
        return tf.image.resize(self.model(shrunk), orig_size)

    @tf.function(reduce_retracing=True,
                 #jit_compile=True, (ScaleAndTranslate is not supported by XLA)
                 input_signature=[tf.TensorSpec([1, None, None, 3],
                                                dtype=tf.float32)])
    def predict(self, x):
        # single-image batches of any height/width share one traced graph
        return self(x)
|
||||||
|
|
||||||
|
class wrap_layout_model_patched(models.Model):
    """
    replacement for layout model using sliding window for patches

    Inputs smaller than the model size (in height or width) are resized
    to the model size, predicted, and resized back. Larger inputs are cut
    into overlapping model-sized patches; each patch is predicted on its
    own, and only the central window of each patch is written back into
    the full-size result.

    (accepts arbitrary width/height input [B, H, W, 3], returns same size segmentation [B, H, W, C])

    NOTE(review): the patched path reconstructs one (height, width, classes)
    map and then adds a batch axis of 1, so it effectively expects B == 1
    (which is also what the `predict` input signature enforces).
    """
    def __init__(self, model, margin=0.1):
        """
        Args:
            model: the layout model to wrap; its size and number of classes
                are read from the output shape of its last layer.
            margin: fraction of the model width/height by which neighbouring
                patches overlap (the stride is ``size * (1 - margin)``); the
                overlap is split between both sides of each patch and
                excluded from the reconstruction.

        Raises:
            ValueError: if ``margin`` is not within [0, 1).
        """
        if not 0 <= margin < 1:
            raise ValueError("margin must be in [0, 1), got %r" % (margin,))
        super().__init__(name=model.name + '_patched')
        self.model = model
        self.height = model.layers[-1].output_shape[1]
        self.width = model.layers[-1].output_shape[2]
        self.classes = model.layers[-1].output_shape[3]
        self.stride_x = int(self.width * (1 - margin))
        self.stride_y = int(self.height * (1 - margin))
        offset_height = (self.height - self.stride_y) // 2
        offset_width = (self.width - self.stride_x) // 2
        # 0/1 mask of model size: ones in the central stride-sized window,
        # zeros in the surrounding margin
        window = tf.image.pad_to_bounding_box(
            tf.ones((self.stride_y, self.stride_x, 1), dtype=tf.int32),
            offset_height, offset_width,
            self.height, self.width)
        self.window = tf.expand_dims(window, axis=0)

    def call(self, img, training=False):
        # `training` is not used by this wrapper
        height = tf.shape(img)[1]
        width = tf.shape(img)[2]
        if (height < self.height or
            width < self.width):
            # too small for patching: resize to model size and back
            img_resized = tf.image.resize(img,
                                          (self.height, self.width),
                                          antialias=True)
            pred_resized = self.model(img_resized)
            pred = tf.image.resize(pred_resized,
                                   (height, width))
            return pred

        img_patches = tf.image.extract_patches(
            images=img,
            sizes=[1, self.height, self.width, 1],
            strides=[1, self.stride_y, self.stride_x, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        # flatten the patch grid into a batch of model-sized images
        # (reshape alone suffices, no squeeze needed beforehand)
        new_shape = (-1, self.height, self.width, 3)
        img_patches = tf.reshape(img_patches, shape=new_shape)
        # may be too large:
        #pred_patches = self.model(img_patches)
        # so rebatch to fit in memory:
        img_patches = tf.expand_dims(img_patches, 1)
        pred_patches = tf.map_fn(self.model, img_patches,
                                 parallel_iterations=1,
                                 infer_shape=False)
        pred_patches = tf.squeeze(pred_patches, 1)

        # calculate corresponding indexes for reconstruction
        # (1-based, so that the zeros which 'SAME' padding inserts outside
        # the image can be told apart from the real pixel at row/column 0)
        x = tf.range(1, width + 1)
        y = tf.range(1, height + 1)
        x, y = tf.meshgrid(x, y)
        indices = tf.stack([y, x], axis=-1)
        indices_patches = tf.image.extract_patches(
            images=tf.expand_dims(indices, axis=0),
            sizes=[1, self.height, self.width, 1],
            strides=[1, self.stride_y, self.stride_x, 1],
            rates=[1, 1, 1, 1],
            padding='SAME')
        indices_patches = tf.reshape(indices_patches, shape=new_shape[:-1] + (2,))

        # use margins for sliding window approach:
        # keep only positions that lie inside the image and inside the
        # central window of their patch
        inside = tf.cast(indices_patches[..., :1] > 0, tf.int32)
        keep = inside * self.window
        # back to 0-based indices; dropped positions point at (0, 0) ...
        indices_patches = (indices_patches - 1) * keep
        # ... so their predictions must be zeroed as well: tf.scatter_nd
        # sums updates with duplicate indices, and they would otherwise
        # all pile up in output pixel (0, 0)
        pred_patches = pred_patches * tf.cast(keep, pred_patches.dtype)

        pred = tf.scatter_nd(
            indices_patches,
            pred_patches,
            (height, width, self.classes))
        pred = tf.expand_dims(pred, axis=0)
        return pred

    @tf.function(reduce_retracing=True,
                 #jit_compile=True, (ScaleAndTranslate and ExtractImagePatches not supported by XLA)
                 input_signature=[tf.TensorSpec([1, None, None, 3],
                                                dtype=tf.float32)])
    def predict(self, x):
        # single-image batches of any height/width share one traced graph
        return self(x)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue