Rewrote the binarization script to always use patches, but in a much more efficient way, and added support for batch conversion with multiple GPUs.

Alexander Pacha 2022-08-30 10:40:54 +02:00
parent b0a8b613e8
commit 4112c6fe71
2 changed files with 136 additions and 239 deletions

requirements.txt

@@ -3,3 +3,4 @@ setuptools >= 41
 opencv-python-headless
 ocrd >= 2.22.3
 tensorflow >= 2.4.0
+mpire

The binarization script:

@@ -1,272 +1,168 @@

New version of the script (added):

"""
Tool to load model and binarize a given image.
"""
import argparse
import gc
import itertools
import math
import os
from pathlib import Path
from typing import Union, List, Any

import cv2
import numpy as np
import tensorflow as tf
from mpire import WorkerPool
from mpire.utils import make_single_arguments
from tensorflow.python.keras.saving.save import load_model


class SbbBinarizer:
    def __init__(self) -> None:
        super().__init__()
        self.model: Any = None
        self.model_height: int = 0
        self.model_width: int = 0
        self.n_classes: int = 0

    def load_model(self, model_dir: Union[str, Path]):
        model_dir = Path(model_dir)
        self.model = load_model(str(model_dir.absolute()), compile=False)
        self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1]
        self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2]
        self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3]

    def binarize_image(self, image_path: Path, save_path: Path):
        if not image_path.exists():
            raise ValueError(f"Image not found: {str(image_path)}")

        # Most operations are expecting BGR as this is the standard way how CV2 reads images
        # noinspection PyUnresolvedReferences
        img = cv2.imread(str(image_path))
        original_image_height, original_image_width, image_channels = img.shape

        # Padded images must be multiples of model size
        padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
        padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
        padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
        padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]

        image_batch = np.expand_dims(padded_image, 0)  # To create the batch information
        patches = tf.image.extract_patches(
            images=image_batch,
            sizes=[1, self.model_height, self.model_width, 1],
            strides=[1, self.model_height, self.model_width, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )

        number_of_horizontal_patches = patches.shape[1]
        number_of_vertical_patches = patches.shape[2]
        total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
        target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)
        # Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels)
        image_patches = tf.reshape(patches, target_shape)
        # Normalize the image to values between 0.0 - 1.0
        image_patches = image_patches / float(255.0)

        predicted_patches = self.model.predict(image_patches)
        # We have to manually call garbage collection and clear_session here to avoid memory leaks.
        # Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
        gc.collect()
        tf.keras.backend.clear_session()

        binary_patches = np.invert(np.argmax(predicted_patches, axis=3).astype(bool)).astype(np.uint8) * 255
        full_image_with_padding = self._patches_to_image(
            binary_patches,
            padded_image_height,
            padded_image_width,
            self.model_height,
            self.model_width
        )
        full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        # noinspection PyUnresolvedReferences
        cv2.imwrite(str(save_path), full_image)

    def _patches_to_image(
            self,
            patches: np.ndarray,
            image_height: int,
            image_width: int,
            patch_height: int,
            patch_width: int
    ):
        height = math.ceil(image_height / patch_height) * patch_height
        width = math.ceil(image_width / patch_width) * patch_width

        image_reshaped = np.reshape(
            np.squeeze(patches),
            [height // patch_height, width // patch_width, patch_height, patch_width]
        )
        image_transposed = np.transpose(a=image_reshaped, axes=[0, 2, 1, 3])
        image_resized = np.reshape(image_transposed, [height, width])
        return image_resized
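The _patches_to_image method is the exact inverse of the tf.image.extract_patches call in binarize_image: because the patches do not overlap, a reshape into a patch grid, a transpose, and a final reshape restore the padded image. A small self-contained round-trip check (toy image and patch size, not part of the commit):

import numpy as np
import tensorflow as tf

# Dummy single-channel "image" whose size is already a multiple of the patch size (4x6, patches of 2x3).
image = np.arange(24, dtype=np.float32).reshape(4, 6, 1)
patch_h, patch_w = 2, 3

# Same call pattern as in binarize_image: non-overlapping patches, one flat vector per patch.
patches = tf.image.extract_patches(
    images=image[np.newaxis, ...],
    sizes=[1, patch_h, patch_w, 1],
    strides=[1, patch_h, patch_w, 1],
    rates=[1, 1, 1, 1],
    padding='SAME'
)
# (1, 2, 2, 6) -> (4, 2, 3, 1): one batch entry per patch, as fed to model.predict
patch_batch = tf.reshape(patches, (-1, patch_h, patch_w, 1)).numpy()

# Inverse step, mirroring _patches_to_image: patch grid -> transpose -> full image.
grid = patch_batch.squeeze(-1).reshape(4 // patch_h, 6 // patch_w, patch_h, patch_w)
restored = np.transpose(grid, (0, 2, 1, 3)).reshape(4, 6)

assert np.array_equal(restored, image[:, :, 0])  # lossless round trip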
def split_list_into_worker_batches(files: List[Any], number_of_workers: int) -> List[List[Any]]:
    """ Splits any given list into batches for the specified number of workers and returns a list of lists. """
    batches = []
    batch_size = math.ceil(len(files) / number_of_workers)
    batch_start = 0
    for i in range(1, number_of_workers + 1):
        batch_end = i * batch_size
        file_batch_to_delete = files[batch_start: batch_end]
        batches.append(file_batch_to_delete)
        batch_start = batch_end
    return batches
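split_list_into_worker_batches simply slices the file list into contiguous chunks of ceil(len(files) / number_of_workers); for example, 10 files over 3 workers yields batches of 4, 4 and 2 (if there are fewer files than workers, the trailing batches are empty):

files = [f"page_{i:04d}.png" for i in range(10)]
batches = split_list_into_worker_batches(files, number_of_workers=3)
print([len(b) for b in batches])  # -> [4, 4, 2]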
def batch_predict(input_data):
    model_dir, input_images, output_images, worker_number = input_data
    print(f"Setting visible cuda devices to {str(worker_number)}")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)

    binarizer = SbbBinarizer()
    binarizer.load_model(model_dir)

    for image_path, output_path in zip(input_images, output_images):
        binarizer.binarize_image(image_path=image_path, save_path=output_path)
        print(f"Binarized {image_path}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_dir', default="model_2021_03_09", help="Path to the directory where the TF model resides or path to an h5 file.")
    parser.add_argument('-i', '--input-path', required=True)
    parser.add_argument('-o', '--output-path', required=True)
    args = parser.parse_args()

    input_path = Path(args.input_path)
    output_path = Path(args.output_path)
    model_directory = args.model_dir

    if input_path.is_dir():
        print(f"Enumerating all PNG files in {str(input_path)}")
        all_input_images = list(input_path.rglob("*.png"))
        print(f"Filtering images that have already been binarized in {str(output_path)}")
        input_images = [i for i in all_input_images if not (output_path / (i.relative_to(input_path))).exists()]
        output_images = [output_path / (i.relative_to(input_path)) for i in input_images]
        input_images = [i for i in input_images]
        print(f"Starting binarization of {len(input_images)} images")

        number_of_gpus = len(tf.config.list_physical_devices('GPU'))
        number_of_workers = max(1, number_of_gpus)
        image_batches = split_list_into_worker_batches(input_images, number_of_workers)
        output_batches = split_list_into_worker_batches(output_images, number_of_workers)

        with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
            model_dirs = itertools.repeat(model_directory, len(image_batches))
            input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
            contents = pool.map_unordered(
                batch_predict,
                make_single_arguments(input_data),
                iterable_len=number_of_workers,
                progress_bar=False
            )
    else:
        binarizer = SbbBinarizer()
        binarizer.load_model(model_directory)
        binarizer.binarize_image(image_path=input_path, save_path=output_path)
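For reference, a single page can also be binarized programmatically with the new class (the paths below are placeholders; model_2021_03_09 is the default model directory from the argument parser above). On the command line the same is done via -m, -i and -o, and pointing -i at a directory switches to the multi-GPU batch mode, where each spawned worker pins itself to one GPU through CUDA_VISIBLE_DEVICES.

from pathlib import Path

binarizer = SbbBinarizer()
binarizer.load_model("model_2021_03_09")
binarizer.binarize_image(image_path=Path("page_0001.png"), save_path=Path("binarized/page_0001.png"))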
Previous version of the script (removed):

"""
Tool to load model and binarize a given image.
"""
import argparse
import sys
from os import environ, devnull
from pathlib import Path
from typing import Union

import cv2
import numpy as np

environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend

sys.stderr = stderr

import logging


def resize_image(img_in, input_height, input_width):
    return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)


class SbbBinarizer:

    def __init__(self, model_dir: Union[str, Path], logger=None):
        model_dir = Path(model_dir)
        self.log = logger if logger else logging.getLogger('SbbBinarizer')

        self.start_new_session()

        self.model_files = list([str(p.absolute()) for p in model_dir.rglob("*.h5")])
        if not self.model_files:
            raise ValueError(f"No models found in {str(model_dir)}")

        self.models = []
        for model_file in self.model_files:
            self.models.append(self.load_model(model_file))

    def start_new_session(self):
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True

        self.session = tf.compat.v1.Session(config=config)  # tf.InteractiveSession()
        tensorflow_backend.set_session(self.session)

    def end_session(self):
        tensorflow_backend.clear_session()
        self.session.close()
        del self.session

    def load_model(self, model_path: str):
        model = load_model(model_path, compile=False)
        model_height = model.layers[len(model.layers) - 1].output_shape[1]
        model_width = model.layers[len(model.layers) - 1].output_shape[2]
        n_classes = model.layers[len(model.layers) - 1].output_shape[3]
        return model, model_height, model_width, n_classes

    def predict(self, model_in, img, use_patches):
        tensorflow_backend.set_session(self.session)
        model, model_height, model_width, n_classes = model_in

        img_org_h = img.shape[0]
        img_org_w = img.shape[1]

        if img.shape[0] < model_height and img.shape[1] >= model_width:
            img_padded = np.zeros((model_height, img.shape[1], img.shape[2]))

            index_start_h = int(abs(img.shape[0] - model_height) / 2.)
            index_start_w = 0

            img_padded[index_start_h: index_start_h + img.shape[0], :, :] = img[:, :, :]

        elif img.shape[0] >= model_height and img.shape[1] < model_width:
            img_padded = np.zeros((img.shape[0], model_width, img.shape[2]))

            index_start_h = 0
            index_start_w = int(abs(img.shape[1] - model_width) / 2.)

            img_padded[:, index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]

        elif img.shape[0] < model_height and img.shape[1] < model_width:
            img_padded = np.zeros((model_height, model_width, img.shape[2]))

            index_start_h = int(abs(img.shape[0] - model_height) / 2.)
            index_start_w = int(abs(img.shape[1] - model_width) / 2.)

            img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]

        else:
            index_start_h = 0
            index_start_w = 0
            img_padded = np.copy(img)

        img = np.copy(img_padded)

        if use_patches:

            margin = int(0.1 * model_width)

            width_mid = model_width - 2 * margin
            height_mid = model_height - 2 * margin

            img = img / float(255.0)

            img_h = img.shape[0]
            img_w = img.shape[1]

            prediction_true = np.zeros((img_h, img_w, 3))
            mask_true = np.zeros((img_h, img_w))
            nxf = img_w / float(width_mid)
            nyf = img_h / float(height_mid)

            if nxf > int(nxf):
                nxf = int(nxf) + 1
            else:
                nxf = int(nxf)

            if nyf > int(nyf):
                nyf = int(nyf) + 1
            else:
                nyf = int(nyf)

            for i in range(nxf):
                for j in range(nyf):

                    if i == 0:
                        index_x_d = i * width_mid
                        index_x_u = index_x_d + model_width
                    elif i > 0:
                        index_x_d = i * width_mid
                        index_x_u = index_x_d + model_width

                    if j == 0:
                        index_y_d = j * height_mid
                        index_y_u = index_y_d + model_height
                    elif j > 0:
                        index_y_d = j * height_mid
                        index_y_u = index_y_d + model_height

                    if index_x_u > img_w:
                        index_x_u = img_w
                        index_x_d = img_w - model_width

                    if index_y_u > img_h:
                        index_y_u = img_h
                        index_y_d = img_h - model_height

                    img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]

                    label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]))

                    seg = np.argmax(label_p_pred, axis=3)[0]
                    seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)

                    if i == 0 and j == 0:
                        seg_color = seg_color[0:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
                        seg = seg[0:seg.shape[0] - margin, 0:seg.shape[1] - margin]
                        mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
                        prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j == nyf - 1:
                        seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :]
                        seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0]
                        mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg
                        prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color
                    elif i == 0 and j == nyf - 1:
                        seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :]
                        seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin]
                        mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg
                        prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j == 0:
                        seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
                        seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0]
                        mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
                        prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
                    elif i == 0 and j != 0 and j != nyf - 1:
                        seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
                        seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin]
                        mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
                        prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
                    elif i == nxf - 1 and j != 0 and j != nyf - 1:
                        seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
                        seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0]
                        mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
                        prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
                    elif i != 0 and i != nxf - 1 and j == 0:
                        seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
                        seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin]
                        mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
                        prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
                    elif i != 0 and i != nxf - 1 and j == nyf - 1:
                        seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :]
                        seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin]
                        mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin] = seg
                        prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin, :] = seg_color
                    else:
                        seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
                        seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - margin]
                        mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
                        prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color

            prediction_true = prediction_true[index_start_h: index_start_h + img_org_h, index_start_w: index_start_w + img_org_w, :]
            prediction_true = prediction_true.astype(np.uint8)

        else:
            img_h_page = img.shape[0]
            img_w_page = img.shape[1]

            img = img / float(255.0)
            img = resize_image(img, model_height, model_width)

            label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))

            seg = np.argmax(label_p_pred, axis=3)[0]
            seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
            prediction_true = resize_image(seg_color, img_h_page, img_w_page)
            prediction_true = prediction_true.astype(np.uint8)

        return prediction_true[:, :, 0]

    def run(self, image=None, image_path=None, save=None, use_patches=False):
        if (image is not None and image_path is not None) or (image is None and image_path is None):
            raise ValueError("Must pass either a opencv2 image or an image_path")
        if image_path is not None:
            image = cv2.imread(image_path)
        img_last = 0
        for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
            self.log.info(f"Predicting with model {model_file} [{n + 1}/{len(self.model_files)}]")

            res = self.predict(model, image, use_patches)

            img_fin = np.zeros((res.shape[0], res.shape[1], 3))
            res[:, :][res[:, :] == 0] = 2
            res = res - 1
            res = res * 255
            img_fin[:, :, 0] = res
            img_fin[:, :, 1] = res
            img_fin[:, :, 2] = res

            img_fin = img_fin.astype(np.uint8)
            img_fin = (res[:, :] == 0) * 255
            img_last = img_last + img_fin

        kernel = np.ones((5, 5), np.uint8)
        img_last[:, :][img_last[:, :] > 0] = 255
        img_last = (img_last[:, :] == 0) * 255
        if save:
            # Create the output directory (and if necessary it's parents) if it doesn't exist already
            Path(save).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(save, img_last)
        return img_last