Merge pull request #1 from bertsky/batch-prediction

fixup for batch prediction PR
This commit is contained in:
Alexander Pacha 2023-03-19 07:59:36 +01:00 committed by GitHub
commit 3ade5eccba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 30 deletions

View file

@ -15,4 +15,4 @@ from .sbb_binarize import SbbBinarizer
def main(model_dir, input_image, output_image):
binarizer = SbbBinarizer()
binarizer.load_model(model_dir)
binarizer.binarize_image(image_path=input_image, save_path=output_image)
binarizer.binarize_image_file(image_path=input_image, save_path=output_image)

View file

@ -69,7 +69,8 @@ class SbbBinarizeProcessor(Processor):
raise FileNotFoundError("Does not exist or is not a directory: %s" % model_path)
# resolve relative path via OCR-D ResourceManager
model_path = self.resolve_resource(str(model_path))
self.binarizer = SbbBinarizer(model_dir=model_path, logger=LOG)
self.binarizer = SbbBinarizer()
self.binarizer.load_model(model_path)
def process(self):
"""
@ -110,7 +111,7 @@ class SbbBinarizeProcessor(Processor):
if oplevel == 'page':
LOG.info("Binarizing on 'page' level in page '%s'", page_id)
bin_image = cv2pil(self.binarizer.run(image=pil2cv(page_image)))
bin_image = cv2pil(self.binarizer.binarize_image(pil2cv(page_image)))
# update METS (add the image file):
bin_image_path = self.workspace.save_image_file(bin_image,
file_id + '.IMG-BIN',
@ -124,7 +125,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text/table regions", page_id)
for region in regions:
region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized')
region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image)))
region_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(region_image)))
region_image_bin_path = self.workspace.save_image_file(
region_image_bin,
"%s_%s.IMG-BIN" % (file_id, region.id),
@ -139,7 +140,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text lines", page_id)
for region_id, line in region_line_tuples:
line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image)))
line_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(line_image)))
line_image_bin_path = self.workspace.save_image_file(
line_image_bin,
"%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id),

View file

@ -3,51 +3,74 @@ import gc
import itertools
import math
import os
import sys
from pathlib import Path
from typing import Union, List, Any
from typing import Union, List, Tuple, Any
import cv2
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')
import tensorflow as tf
from tensorflow.python.keras.saving.save import load_model
sys.stderr = stderr
from mpire import WorkerPool
from mpire.utils import make_single_arguments
from tensorflow.python.keras.saving.save import load_model
class SbbBinarizer:
def __init__(self) -> None:
super().__init__()
self.model: Any = None
self.model_height: int = 0
self.model_width: int = 0
self.n_classes: int = 0
self.models: List[Tuple[Any, int, int, int]] = []
def load_model(self, model_dir: Union[str, Path]):
model_dir = Path(model_dir)
self.model = load_model(str(model_dir.absolute()), compile=False)
self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1]
self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2]
self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3]
model_paths = list(model_dir.glob('*.h5')) or list(model_dir.glob('*/'))
for path in model_paths:
model = load_model(str(path.absolute()), compile=False)
height = model.layers[len(model.layers) - 1].output_shape[1]
width = model.layers[len(model.layers) - 1].output_shape[2]
classes = model.layers[len(model.layers) - 1].output_shape[3]
self.models.append((model, height, width, classes))
def binarize_image(self, image_path: Path, save_path: Path):
def binarize_image_file(self, image_path: Path, save_path: Path):
if not image_path.exists():
raise ValueError(f"Image not found: {str(image_path)}")
# noinspection PyUnresolvedReferences
img = cv2.imread(str(image_path))
full_image = self.binarize_image(img)
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
def binarize_image(self, img: np.ndarray) -> np.ndarray:
img_last = False
for model, model_height, model_width, _ in self.models:
img_res = self.binarize_image_by_model(img, model, model_height, model_width)
img_last = img_last + (img_res == 0)
img_last = (~img_last).astype(np.uint8) * 255
return img_last
def binarize_image_by_model(self, img: np.ndarray, model: Any, model_height: int, model_width: int) -> np.ndarray:
# Padded images must be multiples of model size
original_image_height, original_image_width, image_channels = img.shape
# Padded images must be multiples of model size
padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
padded_image_height = math.ceil(original_image_height / model_height) * model_height
padded_image_width = math.ceil(original_image_width / model_width) * model_width
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension
patches = tf.image.extract_patches(
images=image_batch,
sizes=[1, self.model_height, self.model_width, 1],
strides=[1, self.model_height, self.model_width, 1],
sizes=[1, model_height, model_width, 1],
strides=[1, model_height, model_width, 1],
rates=[1, 1, 1, 1],
padding='SAME'
)
@ -55,17 +78,17 @@ class SbbBinarizer:
number_of_horizontal_patches = patches.shape[1]
number_of_vertical_patches = patches.shape[2]
total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)
target_shape = (total_number_of_patches, model_height, model_width, image_channels)
# Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels)
image_patches = tf.reshape(patches, target_shape)
# Normalize the image to values between 0.0 - 1.0
image_patches = image_patches / float(255.0)
predicted_patches = self.model.predict(image_patches)
predicted_patches = model.predict(image_patches, verbose=0)
# We have to manually call garbage collection and clear_session here to avoid memory leaks.
# Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
gc.collect()
tf.keras.backend.clear_session()
#gc.collect()
#tf.keras.backend.clear_session()
# The result is a white-on-black image that needs to be inverted to be displayed as black-on-white image
# We do this by converting the binary values to a boolean numpy-array and then inverting the values
@ -76,13 +99,11 @@ class SbbBinarizer:
grayscale_patches,
padded_image_height,
padded_image_width,
self.model_height,
self.model_width
model_height,
model_width
)
full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
return full_image
def _patches_to_image(self, patches: np.ndarray, image_height: int, image_width: int, patch_height: int, patch_width: int):
height = math.ceil(image_height / patch_height) * patch_height