You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
sbb_binarization/sbb_binarize/sbb_binarize.py

169 lines
7.3 KiB
Python

import argparse
import gc
import itertools
import math
import os
from pathlib import Path
from typing import Union, List, Any
import cv2
import numpy as np
import tensorflow as tf
from mpire import WorkerPool
from mpire.utils import make_single_arguments
from tensorflow.python.keras.saving.save import load_model
class SbbBinarizer:
    """Binarize images by running a Keras model over non-overlapping patches.

    The model is loaded with ``load_model``; its last layer's output shape
    supplies the patch height, patch width and number of classes used by
    ``binarize_image``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model: Any = None
        self.model_height: int = 0
        self.model_width: int = 0
        self.n_classes: int = 0

    def load_model(self, model_dir: Union[str, Path]):
        """Load the model from ``model_dir`` and record its output geometry.

        Args:
            model_dir: Location handed (as an absolute path string) to the
                Keras ``load_model`` function.
        """
        model_dir = Path(model_dir)
        self.model = load_model(str(model_dir.absolute()), compile=False)
        # Indices 1..3 of the last layer's output shape are read as
        # (height, width, classes); index 0 is skipped.
        last_layer_shape = self.model.layers[-1].output_shape
        self.model_height = last_layer_shape[1]
        self.model_width = last_layer_shape[2]
        self.n_classes = last_layer_shape[3]

    def binarize_image(self, image_path: Union[str, Path], save_path: Union[str, Path]):
        """Binarize one image file and write the result to ``save_path``.

        Pixels whose highest-scoring class index is 0 are written as 255,
        all other pixels as 0.

        Args:
            image_path: Image to read; ``str`` or ``Path``.
            save_path: Destination file; missing parent directories are
                created.

        Raises:
            ValueError: If ``image_path`` does not exist or cannot be decoded.
            RuntimeError: If no model has been loaded yet.
        """
        image_path = Path(image_path)
        save_path = Path(save_path)

        if not image_path.exists():
            raise ValueError(f"Image not found: {str(image_path)}")
        if self.model is None:
            # Without a model the patch size is 0 and the padding arithmetic
            # below would fail with an unhelpful ZeroDivisionError.
            raise RuntimeError("No model loaded; call load_model() before binarize_image()")

        # Most operations are expecting BGR as this is the standard way how CV2 reads images
        # noinspection PyUnresolvedReferences
        img = cv2.imread(str(image_path))
        if img is None:
            # cv2.imread signals an unreadable/unsupported file by returning None.
            raise ValueError(f"Image could not be read: {str(image_path)}")

        original_image_height, original_image_width, image_channels = img.shape

        # Padded images must be multiples of model size
        padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
        padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
        padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
        padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]

        image_batch = np.expand_dims(padded_image, 0)  # To create the batch information
        patches = tf.image.extract_patches(
            images=image_batch,
            sizes=[1, self.model_height, self.model_width, 1],
            strides=[1, self.model_height, self.model_width, 1],
            rates=[1, 1, 1, 1],
            padding='SAME'
        )

        # Axis 1 of the patch tensor runs along the image height (patch rows),
        # axis 2 along the image width (patch columns).
        number_of_patch_rows = patches.shape[1]
        number_of_patch_columns = patches.shape[2]
        total_number_of_patches = number_of_patch_rows * number_of_patch_columns
        target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)

        # Flatten the (rows, columns) grid of patches into a single big batch
        # of shape (b, height, width, channels)
        image_patches = tf.reshape(patches, target_shape)
        # Normalize the image to values between 0.0 - 1.0
        image_patches = image_patches / 255.0

        predicted_patches = self.model.predict(image_patches)
        # We have to manually call garbage collection and clear_session here to avoid memory leaks.
        # Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
        gc.collect()
        tf.keras.backend.clear_session()

        # Class index 0 -> 255, any other class index -> 0.
        predicted_classes = np.argmax(predicted_patches, axis=3)
        binary_patches = np.invert(predicted_classes.astype(bool)).astype(np.uint8) * 255

        full_image_with_padding = self._patches_to_image(
            binary_patches,
            padded_image_height,
            padded_image_width,
            self.model_height,
            self.model_width
        )
        # Drop the padding that was added to reach a multiple of the patch size.
        full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]

        save_path.parent.mkdir(parents=True, exist_ok=True)
        # noinspection PyUnresolvedReferences
        cv2.imwrite(str(save_path), full_image)

    def _patches_to_image(
            self,
            patches: np.ndarray,
            image_height: int,
            image_width: int,
            patch_height: int,
            patch_width: int
    ):
        """Stitch row-major ordered single-channel patches into one 2-D image.

        Args:
            patches: Array holding ``rows * columns`` patches of size
                ``patch_height x patch_width``, ordered row by row.
            image_height: Image height; rounded up to a multiple of
                ``patch_height``.
            image_width: Image width; rounded up to a multiple of
                ``patch_width``.
            patch_height: Height of one patch.
            patch_width: Width of one patch.

        Returns:
            A 2-D array of shape (rounded height, rounded width).
        """
        rows = math.ceil(image_height / patch_height)
        columns = math.ceil(image_width / patch_width)
        height = rows * patch_height
        width = columns * patch_width

        # (row, column, y, x) -> (row, y, column, x) so that collapsing the
        # first two and the last two axes yields contiguous image rows.
        grid = np.reshape(patches, [rows, columns, patch_height, patch_width])
        return np.reshape(grid.transpose(0, 2, 1, 3), [height, width])
def split_list_into_worker_batches(files: List[Any], number_of_workers: int) -> List[List[Any]]:
    """Cut ``files`` into ``number_of_workers`` consecutive slices.

    Every slice holds ``ceil(len(files) / number_of_workers)`` items, except
    that trailing slices may be shorter or empty once the list runs out.
    """
    chunk_length = math.ceil(len(files) / number_of_workers)
    return [
        files[index * chunk_length: (index + 1) * chunk_length]
        for index in range(number_of_workers)
    ]
def batch_predict(input_data):
    """Binarize one worker's share of images with a freshly loaded model.

    ``input_data`` is a 4-tuple ``(model_dir, input_images, output_images,
    worker_number)``. ``worker_number`` is written to the
    ``CUDA_VISIBLE_DEVICES`` environment variable before the model is
    loaded; input and output paths are processed pairwise.
    """
    model_dir, input_images, output_images, worker_number = input_data

    print(f"Setting visible cuda devices to {str(worker_number)}")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)

    worker_binarizer = SbbBinarizer()
    worker_binarizer.load_model(model_dir)

    for image_path, save_path in zip(input_images, output_images):
        worker_binarizer.binarize_image(image_path=image_path, save_path=save_path)
        print(f"Binarized {image_path}")
if __name__ == '__main__':
    # Command-line entry point. A directory input is searched recursively for
    # PNG files, which are spread over one worker per visible GPU (at least
    # one worker); any other input is treated as a single image file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_dir', default="model_2021_03_09", help="Path to the directory where the TF model resides or path to an h5 file.")
    parser.add_argument('-i', '--input-path', required=True)
    parser.add_argument('-o', '--output-path', required=True)
    args = parser.parse_args()

    input_path = Path(args.input_path)
    output_path = Path(args.output_path)
    model_directory = args.model_dir

    if input_path.is_dir():
        print(f"Enumerating all PNG files in {str(input_path)}")
        all_input_images = list(input_path.rglob("*.png"))

        print(f"Filtering images that have already been binarized in {str(output_path)}")
        # The output mirrors the input directory layout below output_path.
        # Each target path is computed once and reused for both the
        # "already done" check and the list of outputs.
        pending = [
            (image, target)
            for image, target in (
                (image, output_path / image.relative_to(input_path))
                for image in all_input_images
            )
            if not target.exists()
        ]
        input_images = [image for image, _ in pending]
        output_images = [target for _, target in pending]
        print(f"Starting binarization of {len(input_images)} images")

        # Skip spawning workers (and loading the model in each of them) when
        # there is nothing left to binarize.
        if input_images:
            number_of_gpus = len(tf.config.list_physical_devices('GPU'))
            number_of_workers = max(1, number_of_gpus)

            # Both lists have the same length, so they are split at the same
            # positions and the batches stay aligned pairwise.
            image_batches = split_list_into_worker_batches(input_images, number_of_workers)
            output_batches = split_list_into_worker_batches(output_images, number_of_workers)

            with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
                model_dirs = itertools.repeat(model_directory, len(image_batches))
                # The last tuple element is the worker index that
                # batch_predict receives as worker_number.
                input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
                pool.map_unordered(
                    batch_predict,
                    make_single_arguments(input_data),
                    iterable_len=number_of_workers,
                    progress_bar=False
                )
    else:
        binarizer = SbbBinarizer()
        binarizer.load_model(model_directory)
        binarizer.binarize_image(image_path=input_path, save_path=output_path)