Merge pull request #1 from bertsky/batch-prediction

fixup for batch prediction PR
pull/48/head
Alexander Pacha 2 years ago committed by GitHub
commit 3ade5eccba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,4 +15,4 @@ from .sbb_binarize import SbbBinarizer
def main(model_dir, input_image, output_image):
binarizer = SbbBinarizer()
binarizer.load_model(model_dir)
binarizer.binarize_image(image_path=input_image, save_path=output_image)
binarizer.binarize_image_file(image_path=input_image, save_path=output_image)

@ -69,7 +69,8 @@ class SbbBinarizeProcessor(Processor):
raise FileNotFoundError("Does not exist or is not a directory: %s" % model_path)
# resolve relative path via OCR-D ResourceManager
model_path = self.resolve_resource(str(model_path))
self.binarizer = SbbBinarizer(model_dir=model_path, logger=LOG)
self.binarizer = SbbBinarizer()
self.binarizer.load_model(model_path)
def process(self):
"""
@ -110,7 +111,7 @@ class SbbBinarizeProcessor(Processor):
if oplevel == 'page':
LOG.info("Binarizing on 'page' level in page '%s'", page_id)
bin_image = cv2pil(self.binarizer.run(image=pil2cv(page_image)))
bin_image = cv2pil(self.binarizer.binarize_image(pil2cv(page_image)))
# update METS (add the image file):
bin_image_path = self.workspace.save_image_file(bin_image,
file_id + '.IMG-BIN',
@ -124,7 +125,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text/table regions", page_id)
for region in regions:
region_image, region_xywh = self.workspace.image_from_segment(region, page_image, page_xywh, feature_filter='binarized')
region_image_bin = cv2pil(binarizer.run(image=pil2cv(region_image)))
region_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(region_image)))
region_image_bin_path = self.workspace.save_image_file(
region_image_bin,
"%s_%s.IMG-BIN" % (file_id, region.id),
@ -139,7 +140,7 @@ class SbbBinarizeProcessor(Processor):
LOG.warning("Page '%s' contains no text lines", page_id)
for region_id, line in region_line_tuples:
line_image, line_xywh = self.workspace.image_from_segment(line, page_image, page_xywh, feature_filter='binarized')
line_image_bin = cv2pil(binarizer.run(image=pil2cv(line_image)))
line_image_bin = cv2pil(self.binarizer.binarize_image(image=pil2cv(line_image)))
line_image_bin_path = self.workspace.save_image_file(
line_image_bin,
"%s_%s_%s.IMG-BIN" % (file_id, region_id, line.id),

@ -3,51 +3,74 @@ import gc
import itertools
import math
import os
import sys
from pathlib import Path
from typing import Union, List, Any
from typing import Union, List, Tuple, Any
import cv2
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')
import tensorflow as tf
from tensorflow.python.keras.saving.save import load_model
sys.stderr = stderr
from mpire import WorkerPool
from mpire.utils import make_single_arguments
from tensorflow.python.keras.saving.save import load_model
class SbbBinarizer:
def __init__(self) -> None:
super().__init__()
self.model: Any = None
self.model_height: int = 0
self.model_width: int = 0
self.n_classes: int = 0
self.models: List[Tuple[Any, int, int, int]] = []
def load_model(self, model_dir: Union[str, Path]):
model_dir = Path(model_dir)
self.model = load_model(str(model_dir.absolute()), compile=False)
self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1]
self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2]
self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3]
def binarize_image(self, image_path: Path, save_path: Path):
model_paths = list(model_dir.glob('*.h5')) or list(model_dir.glob('*/'))
for path in model_paths:
model = load_model(str(path.absolute()), compile=False)
height = model.layers[len(model.layers) - 1].output_shape[1]
width = model.layers[len(model.layers) - 1].output_shape[2]
classes = model.layers[len(model.layers) - 1].output_shape[3]
self.models.append((model, height, width, classes))
def binarize_image_file(self, image_path: Path, save_path: Path):
if not image_path.exists():
raise ValueError(f"Image not found: {str(image_path)}")
# noinspection PyUnresolvedReferences
img = cv2.imread(str(image_path))
original_image_height, original_image_width, image_channels = img.shape
full_image = self.binarize_image(img)
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
def binarize_image(self, img: np.ndarray) -> np.ndarray:
img_last = False
for model, model_height, model_width, _ in self.models:
img_res = self.binarize_image_by_model(img, model, model_height, model_width)
img_last = img_last + (img_res == 0)
img_last = (~img_last).astype(np.uint8) * 255
return img_last
def binarize_image_by_model(self, img: np.ndarray, model: Any, model_height: int, model_width: int) -> np.ndarray:
# Padded images must be multiples of model size
padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
original_image_height, original_image_width, image_channels = img.shape
padded_image_height = math.ceil(original_image_height / model_height) * model_height
padded_image_width = math.ceil(original_image_width / model_width) * model_width
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension
patches = tf.image.extract_patches(
images=image_batch,
sizes=[1, self.model_height, self.model_width, 1],
strides=[1, self.model_height, self.model_width, 1],
sizes=[1, model_height, model_width, 1],
strides=[1, model_height, model_width, 1],
rates=[1, 1, 1, 1],
padding='SAME'
)
@ -55,17 +78,17 @@ class SbbBinarizer:
number_of_horizontal_patches = patches.shape[1]
number_of_vertical_patches = patches.shape[2]
total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)
target_shape = (total_number_of_patches, model_height, model_width, image_channels)
# Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels)
image_patches = tf.reshape(patches, target_shape)
# Normalize the image to values between 0.0 - 1.0
image_patches = image_patches / float(255.0)
predicted_patches = self.model.predict(image_patches)
predicted_patches = model.predict(image_patches, verbose=0)
# We have to manually call garbage collection and clear_session here to avoid memory leaks.
# Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
gc.collect()
tf.keras.backend.clear_session()
#gc.collect()
#tf.keras.backend.clear_session()
# The result is a white-on-black image that needs to be inverted to be displayed as black-on-white image
# We do this by converting the binary values to a boolean numpy-array and then inverting the values
@ -76,13 +99,11 @@ class SbbBinarizer:
grayscale_patches,
padded_image_height,
padded_image_width,
self.model_height,
self.model_width
model_height,
model_width
)
full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
return full_image
def _patches_to_image(self, patches: np.ndarray, image_height: int, image_width: int, patch_height: int, patch_width: int):
height = math.ceil(image_height / patch_height) * patch_height

Loading…
Cancel
Save