Merge pull request #1 from bertsky/batch-prediction

fixup for batch prediction PR
pull/48/head
Alexander Pacha committed 2 years ago (via GitHub)
commit 3ade5eccba

@@ -15,4 +15,4 @@ from .sbb_binarize import SbbBinarizer
 def main(model_dir, input_image, output_image):
     binarizer = SbbBinarizer()
     binarizer.load_model(model_dir)
-    binarizer.binarize_image(image_path=input_image, save_path=output_image)
+    binarizer.binarize_image_file(image_path=input_image, save_path=output_image)
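For orientation, a minimal usage sketch of the renamed file-level entry point; the import path and the model directory are assumptions based on the relative import in the hunk header above:

    from pathlib import Path
    from sbb_binarize.sbb_binarize import SbbBinarizer  # assumed package layout

    binarizer = SbbBinarizer()
    binarizer.load_model('/path/to/models')  # placeholder: directory with one or more .h5 models
    binarizer.binarize_image_file(image_path=Path('page.tif'), save_path=Path('page.bin.png'))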

@@ -69,7 +69,8 @@ class SbbBinarizeProcessor(Processor):
             raise FileNotFoundError("Does not exist or is not a directory: %s" % model_path)
         # resolve relative path via OCR-D ResourceManager
         model_path = self.resolve_resource(str(model_path))
-        self.binarizer = SbbBinarizer(model_dir=model_path, logger=LOG)
+        self.binarizer = SbbBinarizer()
+        self.binarizer.load_model(model_path)

    def process(self):
        """
@@ -110,7 +111,7 @@ class SbbBinarizeProcessor(Processor):
             if oplevel == 'page':
                 LOG.info("Binarizing on 'page' level in page '%s'", page_id)
-                bin_image = cv2pil(self.binarizer.run(image=pil2cv(page_image)))
+                bin_image = cv2pil(self.binarizer.binarize_image(pil2cv(page_image)))
                 # update METS (add the image file):
                 bin_image_path = self.workspace.save_image_file(bin_image,
                     file_id + '.IMG-BIN',
@@ -124,7 +125,7 @@ class SbbBinarizeProcessor(Processor):
                     LOG.warning("Page '%s' contains no text/table regions", page_id)
                 for region in regions:
                     region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized')
-                    region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image)))
+                    region_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(region_image)))
                     region_image_bin_path = self.workspace.save_image_file(
                         region_image_bin,
                         "%s_%s.IMG-BIN" % (file_id, region.id),
@@ -139,7 +140,7 @@ class SbbBinarizeProcessor(Processor):
                     LOG.warning("Page '%s' contains no text lines", page_id)
                 for region_id, line in region_line_tuples:
                     line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
-                    line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image)))
+                    line_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(line_image)))
                     line_image_bin_path = self.workspace.save_image_file(
                         line_image_bin,
                         "%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id),

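All three operation levels above reduce to the same round trip: crop a PIL image from the workspace, convert it to an OpenCV array, binarize it, and convert back before saving. A rough standalone sketch of that round trip; the converter helpers here are written out explicitly and only approximate the cv2pil/pil2cv used in ocrd_cli.py, and the import path and file names are placeholders:

    import cv2
    import numpy as np
    from PIL import Image
    from sbb_binarize.sbb_binarize import SbbBinarizer  # assumed package layout

    def pil2cv_sketch(pil_img: Image.Image) -> np.ndarray:
        # PIL works in RGB, OpenCV in BGR
        return cv2.cvtColor(np.array(pil_img.convert('RGB')), cv2.COLOR_RGB2BGR)

    def cv2pil_sketch(cv_img: np.ndarray) -> Image.Image:
        if cv_img.ndim == 2:  # the binarized result may be single-channel
            return Image.fromarray(cv_img)
        return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))

    binarizer = SbbBinarizer()
    binarizer.load_model('/path/to/models')      # placeholder model directory
    segment_pil = Image.open('region_crop.png')  # placeholder crop
    segment_bin = cv2pil_sketch(binarizer.binarize_image(pil2cv_sketch(segment_pil)))
    segment_bin.save('region_crop.bin.png')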
@@ -3,51 +3,74 @@ import gc
 import itertools
 import math
 import os
+import sys
 from pathlib import Path
-from typing import Union, List, Any
+from typing import Union, List, Tuple, Any

 import cv2
 import numpy as np
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+stderr = sys.stderr
+sys.stderr = open(os.devnull, 'w')
 import tensorflow as tf
+from tensorflow.python.keras.saving.save import load_model
+sys.stderr = stderr
+
 from mpire import WorkerPool
 from mpire.utils import make_single_arguments
-from tensorflow.python.keras.saving.save import load_model


 class SbbBinarizer:

     def __init__(self) -> None:
         super().__init__()
-        self.model: Any = None
-        self.model_height: int = 0
-        self.model_width: int = 0
-        self.n_classes: int = 0
+        self.models: List[Tuple[Any, int, int, int]] = []

     def load_model(self, model_dir: Union[str, Path]):
         model_dir = Path(model_dir)
-        self.model = load_model(str(model_dir.absolute()), compile=False)
-        self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1]
-        self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2]
-        self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3]
+        model_paths = list(model_dir.glob('*.h5')) or list(model_dir.glob('*/'))
+        for path in model_paths:
+            model = load_model(str(path.absolute()), compile=False)
+            height = model.layers[len(model.layers) - 1].output_shape[1]
+            width = model.layers[len(model.layers) - 1].output_shape[2]
+            classes = model.layers[len(model.layers) - 1].output_shape[3]
+            self.models.append((model, height, width, classes))

-    def binarize_image(self, image_path: Path, save_path: Path):
+    def binarize_image_file(self, image_path: Path, save_path: Path):
         if not image_path.exists():
             raise ValueError(f"Image not found: {str(image_path)}")

         # noinspection PyUnresolvedReferences
         img = cv2.imread(str(image_path))
-        original_image_height, original_image_width, image_channels = img.shape
+        full_image = self.binarize_image(img)
+        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
+        # noinspection PyUnresolvedReferences
+        cv2.imwrite(str(save_path), full_image)
+
+    def binarize_image(self, img: np.ndarray) -> np.ndarray:
+        img_last = False
+        for model, model_height, model_width, _ in self.models:
+            img_res = self.binarize_image_by_model(img, model, model_height, model_width)
+            img_last = img_last + (img_res == 0)
+        img_last = (~img_last).astype(np.uint8) * 255
+        return img_last
+
+    def binarize_image_by_model(self, img: np.ndarray, model: Any, model_height: int, model_width: int) -> np.ndarray:
         # Padded images must be multiples of model size
-        padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
-        padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
+        original_image_height, original_image_width, image_channels = img.shape
+
+        padded_image_height = math.ceil(original_image_height / model_height) * model_height
+        padded_image_width = math.ceil(original_image_width / model_width) * model_width
         padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
         padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
         image_batch = np.expand_dims(padded_image, 0)  # Create the batch dimension

         patches = tf.image.extract_patches(
             images=image_batch,
-            sizes=[1, self.model_height, self.model_width, 1],
-            strides=[1, self.model_height, self.model_width, 1],
+            sizes=[1, model_height, model_width, 1],
+            strides=[1, model_height, model_width, 1],
             rates=[1, 1, 1, 1],
             padding='SAME'
         )
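One detail of the new binarize_image worth spelling out: img_last starts as False, and NumPy adds boolean arrays as a logical OR, so a pixel ends up black (0) if any of the loaded models classified it as black, and white (255) only if all of them agreed. A tiny worked example of just that combination step (the two arrays are made up for illustration):

    import numpy as np

    # hypothetical per-model outputs: 0 = black (foreground), 255 = white (background)
    model_a = np.array([[0, 255], [255, 255]], dtype=np.uint8)
    model_b = np.array([[0, 255], [0, 255]], dtype=np.uint8)

    img_last = False
    for img_res in (model_a, model_b):
        img_last = img_last + (img_res == 0)  # boolean + boolean acts as logical OR

    img_last = (~img_last).astype(np.uint8) * 255
    print(img_last)
    # [[  0 255]
    #  [  0 255]]  -> black wherever at least one model voted black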
@@ -55,17 +78,17 @@ class SbbBinarizer:
         number_of_horizontal_patches = patches.shape[1]
         number_of_vertical_patches = patches.shape[2]
         total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
-        target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)
+        target_shape = (total_number_of_patches, model_height, model_width, image_channels)
         # Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels)
         image_patches = tf.reshape(patches, target_shape)
         # Normalize the image to values between 0.0 - 1.0
         image_patches = image_patches / float(255.0)

-        predicted_patches = self.model.predict(image_patches)
+        predicted_patches = model.predict(image_patches, verbose=0)
         # We have to manually call garbage collection and clear_session here to avoid memory leaks.
         # Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
-        gc.collect()
-        tf.keras.backend.clear_session()
+        #gc.collect()
+        #tf.keras.backend.clear_session()

         # The result is a white-on-black image that needs to be inverted to be displayed as black-on-white image
         # We do this by converting the binary values to a boolean numpy-array and then inverting the values
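For reference, the patch bookkeeping in the two hunks above can be checked in isolation: padding rounds the image up to a multiple of the patch size, tf.image.extract_patches returns a (1, rows, cols, h*w*c) tensor, and the reshape turns that grid into a single batch of (rows*cols, h, w, c) patches. A small sketch with made-up sizes:

    import math
    import numpy as np
    import tensorflow as tf

    model_height, model_width = 64, 64  # illustrative, not the real model input size
    img = np.random.randint(0, 256, size=(100, 150, 3)).astype(np.float32)

    h, w, c = img.shape
    padded_h = math.ceil(h / model_height) * model_height  # 128
    padded_w = math.ceil(w / model_width) * model_width    # 192
    padded = np.zeros((padded_h, padded_w, c), dtype=np.float32)
    padded[:h, :w, :] = img

    patches = tf.image.extract_patches(
        images=np.expand_dims(padded, 0),
        sizes=[1, model_height, model_width, 1],
        strides=[1, model_height, model_width, 1],
        rates=[1, 1, 1, 1],
        padding='SAME')
    print(patches.shape)  # (1, 2, 3, 12288): a 2x3 grid of flattened 64x64x3 patches

    batch = tf.reshape(patches, (2 * 3, model_height, model_width, c)) / 255.0
    print(batch.shape)    # (6, 64, 64, 3), ready for model.predict()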
@@ -76,13 +99,11 @@ class SbbBinarizer:
             grayscale_patches,
             padded_image_height,
             padded_image_width,
-            self.model_height,
-            self.model_width
+            model_height,
+            model_width
         )
         full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
-        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
-        # noinspection PyUnresolvedReferences
-        cv2.imwrite(str(save_path), full_image)
+        return full_image

     def _patches_to_image(self, patches: np.ndarray, image_height: int, image_width: int, patch_height: int, patch_width: int):
         height = math.ceil(image_height / patch_height) * patch_height
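The diff stops at the signature of _patches_to_image, whose body is unchanged. For non-overlapping patches like the ones produced above, the inverse operation is essentially a reshape plus transpose; a hypothetical sketch for single-channel patches (not the actual body of _patches_to_image):

    import math
    import numpy as np

    def patches_to_image_sketch(patches: np.ndarray, image_height: int, image_width: int,
                                patch_height: int, patch_width: int) -> np.ndarray:
        # patches: (rows*cols, patch_height, patch_width), in row-major grid order
        rows = math.ceil(image_height / patch_height)
        cols = math.ceil(image_width / patch_width)
        grid = patches.reshape(rows, cols, patch_height, patch_width)
        # interleave the grid axes with the patch axes, then flatten to the padded image
        return grid.transpose(0, 2, 1, 3).reshape(rows * patch_height, cols * patch_width)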
