reinstate ensemble combination and in-memory prediction

pull/48/head
Robert Sachunsky 2 years ago
parent cade5dda73
commit 342e94e287

@ -15,4 +15,4 @@ from .sbb_binarize import SbbBinarizer
def main(model_dir, input_image, output_image): def main(model_dir, input_image, output_image):
binarizer = SbbBinarizer() binarizer = SbbBinarizer()
binarizer.load_model(model_dir) binarizer.load_model(model_dir)
binarizer.binarize_image(image_path=input_image, save_path=output_image) binarizer.binarize_image_file(image_path=input_image, save_path=output_image)

@ -5,7 +5,7 @@ import math
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Union, List, Any from typing import Union, List, Tuple, Any
import cv2 import cv2
import numpy as np import numpy as np
@ -24,37 +24,53 @@ from mpire.utils import make_single_arguments
class SbbBinarizer: class SbbBinarizer:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.model: Any = None self.models: List[Tuple[Any, int, int, int]] = []
self.model_height: int = 0
self.model_width: int = 0
self.n_classes: int = 0
def load_model(self, model_dir: Union[str, Path]): def load_model(self, model_dir: Union[str, Path]):
model_dir = Path(model_dir) model_dir = Path(model_dir)
self.model = load_model(str(model_dir.absolute()), compile=False) model_paths = list(model_dir.glob('*.h5')) or list(model_dir.glob('*/'))
self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1] for path in model_paths:
self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2] model = load_model(str(path.absolute()), compile=False)
self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3] height = model.layers[len(model.layers) - 1].output_shape[1]
width = model.layers[len(model.layers) - 1].output_shape[2]
def binarize_image(self, image_path: Path, save_path: Path): classes = model.layers[len(model.layers) - 1].output_shape[3]
self.models.append((model, height, width, classes))
def binarize_image_file(self, image_path: Path, save_path: Path):
if not image_path.exists(): if not image_path.exists():
raise ValueError(f"Image not found: {str(image_path)}") raise ValueError(f"Image not found: {str(image_path)}")
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
img = cv2.imread(str(image_path)) img = cv2.imread(str(image_path))
original_image_height, original_image_width, image_channels = img.shape
full_image = self.binarize_image(img)
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
def binarize_image(self, img: np.ndarray) -> np.ndarray:
img_last = False
for model, model_height, model_width, _ in self.models:
img_res = self.binarize_image_by_model(img, model, model_height, model_width)
img_last = img_last + (img_res == 0)
img_last = (~img_last).astype(np.uint8) * 255
return img_last
def binarize_image_by_model(self, img: np.ndarray, model: Any, model_height: int, model_width: int) -> np.ndarray:
# Padded images must be multiples of model size # Padded images must be multiples of model size
padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height original_image_height, original_image_width, image_channels = img.shape
padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
padded_image_height = math.ceil(original_image_height / model_height) * model_height
padded_image_width = math.ceil(original_image_width / model_width) * model_width
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels)) padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :] padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]
image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension
patches = tf.image.extract_patches( patches = tf.image.extract_patches(
images=image_batch, images=image_batch,
sizes=[1, self.model_height, self.model_width, 1], sizes=[1, model_height, model_width, 1],
strides=[1, self.model_height, self.model_width, 1], strides=[1, model_height, model_width, 1],
rates=[1, 1, 1, 1], rates=[1, 1, 1, 1],
padding='SAME' padding='SAME'
) )
@ -62,13 +78,13 @@ class SbbBinarizer:
number_of_horizontal_patches = patches.shape[1] number_of_horizontal_patches = patches.shape[1]
number_of_vertical_patches = patches.shape[2] number_of_vertical_patches = patches.shape[2]
total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels) target_shape = (total_number_of_patches, model_height, model_width, image_channels)
# Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels) # Squeeze all image patches (n, m, width, height, channels) into a single big batch (b, width, height, channels)
image_patches = tf.reshape(patches, target_shape) image_patches = tf.reshape(patches, target_shape)
# Normalize the image to values between 0.0 - 1.0 # Normalize the image to values between 0.0 - 1.0
image_patches = image_patches / float(255.0) image_patches = image_patches / float(255.0)
predicted_patches = self.model.predict(image_patches) predicted_patches = model.predict(image_patches, verbose=0)
# We have to manually call garbage collection and clear_session here to avoid memory leaks. # We have to manually call garbage collection and clear_session here to avoid memory leaks.
# Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501 # Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
gc.collect() gc.collect()
@ -83,13 +99,11 @@ class SbbBinarizer:
grayscale_patches, grayscale_patches,
padded_image_height, padded_image_height,
padded_image_width, padded_image_width,
self.model_height, model_height,
self.model_width model_width
) )
full_image = full_image_with_padding[0:original_image_height, 0:original_image_width] full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
Path(save_path).parent.mkdir(parents=True, exist_ok=True) return full_image
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
def _patches_to_image(self, patches: np.ndarray, image_height: int, image_width: int, patch_height: int, patch_width: int): def _patches_to_image(self, patches: np.ndarray, image_height: int, image_width: int, patch_height: int, patch_width: int):
height = math.ceil(image_height / patch_height) * patch_height height = math.ceil(image_height / patch_height) * patch_height

Loading…
Cancel
Save