Improved loading of models to allow providing a directory, added a few type hints, and improved the code style a little by running an auto-formatter on the entire file.

pull/48/head
Alexander Pacha 2 years ago
parent f11d0b0bf7
commit b0a8b613e8

@ -1,41 +1,42 @@
""" """
Tool to load model and binarize a given image. Tool to load model and binarize a given image.
""" """
import argparse
import sys import sys
from glob import glob
from os import environ, devnull from os import environ, devnull
from os.path import join from pathlib import Path
from warnings import catch_warnings, simplefilter from typing import Union
import numpy as np
from PIL import Image
import cv2 import cv2
import numpy as np
environ['TF_CPP_MIN_LOG_LEVEL'] = '3' environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr stderr = sys.stderr
sys.stderr = open(devnull, 'w') sys.stderr = open(devnull, 'w')
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend from tensorflow.python.keras import backend as tensorflow_backend
sys.stderr = stderr
sys.stderr = stderr
import logging import logging
def resize_image(img_in, input_height, input_width): def resize_image(img_in, input_height, input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
class SbbBinarizer: class SbbBinarizer:
def __init__(self, model_dir, logger=None): def __init__(self, model_dir: Union[str, Path], logger=None):
self.model_dir = model_dir model_dir = Path(model_dir)
self.log = logger if logger else logging.getLogger('SbbBinarizer') self.log = logger if logger else logging.getLogger('SbbBinarizer')
self.start_new_session() self.start_new_session()
self.model_files = glob('%s/*.h5' % self.model_dir) self.model_files = list([str(p.absolute()) for p in model_dir.rglob("*.h5")])
if not self.model_files: if not self.model_files:
raise ValueError(f"No models found in {self.model_dir}") raise ValueError(f"No models found in {str(model_dir)}")
self.models = [] self.models = []
for model_file in self.model_files: for model_file in self.model_files:
@ -53,11 +54,11 @@ class SbbBinarizer:
self.session.close() self.session.close()
del self.session del self.session
def load_model(self, model_name): def load_model(self, model_path: str):
model = load_model(model_name, compile=False) model = load_model(model_path, compile=False)
model_height = model.layers[len(model.layers)-1].output_shape[1] model_height = model.layers[len(model.layers) - 1].output_shape[1]
model_width = model.layers[len(model.layers)-1].output_shape[2] model_width = model.layers[len(model.layers) - 1].output_shape[2]
n_classes = model.layers[len(model.layers)-1].output_shape[3] n_classes = model.layers[len(model.layers) - 1].output_shape[3]
return model, model_height, model_width, n_classes return model, model_height, model_width, n_classes
def predict(self, model_in, img, use_patches): def predict(self, model_in, img, use_patches):
@ -68,40 +69,37 @@ class SbbBinarizer:
img_org_w = img.shape[1] img_org_w = img.shape[1]
if img.shape[0] < model_height and img.shape[1] >= model_width: if img.shape[0] < model_height and img.shape[1] >= model_width:
img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] )) img_padded = np.zeros((model_height, img.shape[1], img.shape[2]))
index_start_h = int( abs( img.shape[0] - model_height) /2.) index_start_h = int(abs(img.shape[0] - model_height) / 2.)
index_start_w = 0 index_start_w = 0
img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:] img_padded[index_start_h: index_start_h + img.shape[0], :, :] = img[:, :, :]
elif img.shape[0] >= model_height and img.shape[1] < model_width: elif img.shape[0] >= model_height and img.shape[1] < model_width:
img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] )) img_padded = np.zeros((img.shape[0], model_width, img.shape[2]))
index_start_h = 0 index_start_h = 0
index_start_w = int( abs( img.shape[1] - model_width) /2.) index_start_w = int(abs(img.shape[1] - model_width) / 2.)
img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:] img_padded[:, index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
elif img.shape[0] < model_height and img.shape[1] < model_width: elif img.shape[0] < model_height and img.shape[1] < model_width:
img_padded = np.zeros(( model_height, model_width, img.shape[2] )) img_padded = np.zeros((model_height, model_width, img.shape[2]))
index_start_h = int( abs( img.shape[0] - model_height) /2.) index_start_h = int(abs(img.shape[0] - model_height) / 2.)
index_start_w = int( abs( img.shape[1] - model_width) /2.) index_start_w = int(abs(img.shape[1] - model_width) / 2.)
img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:] img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]
else: else:
index_start_h = 0 index_start_h = 0
index_start_w = 0 index_start_w = 0
img_padded = np.copy(img) img_padded = np.copy(img)
img = np.copy(img_padded) img = np.copy(img_padded)
if use_patches: if use_patches:
margin = int(0.1 * model_width) margin = int(0.1 * model_width)
@ -109,7 +107,6 @@ class SbbBinarizer:
width_mid = model_width - 2 * margin width_mid = model_width - 2 * margin
height_mid = model_height - 2 * margin height_mid = model_height - 2 * margin
img = img / float(255.0) img = img / float(255.0)
img_h = img.shape[0] img_h = img.shape[0]
@ -169,49 +166,49 @@ class SbbBinarizer:
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
elif i == nxf-1 and j == nyf-1: elif i == nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :] seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0] seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0]
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color
elif i == 0 and j == nyf-1: elif i == 0 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :] seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin] seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin]
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color
elif i == nxf-1 and j == 0: elif i == nxf - 1 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :] seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0] seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0]
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
elif i == 0 and j != 0 and j != nyf-1: elif i == 0 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :] seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin] seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin]
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
elif i == nxf-1 and j != 0 and j != nyf-1: elif i == nxf - 1 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :] seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0] seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0]
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color
elif i != 0 and i != nxf-1 and j == 0: elif i != 0 and i != nxf - 1 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :] seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin] seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin]
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
elif i != 0 and i != nxf-1 and j == nyf-1: elif i != 0 and i != nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :] seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin] seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin]
@ -225,9 +222,7 @@ class SbbBinarizer:
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
prediction_true = prediction_true[index_start_h: index_start_h + img_org_h, index_start_w: index_start_w + img_org_w, :]
prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:]
prediction_true = prediction_true.astype(np.uint8) prediction_true = prediction_true.astype(np.uint8)
else: else:
@ -242,17 +237,16 @@ class SbbBinarizer:
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page) prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8) prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:,:,0] return prediction_true[:, :, 0]
def run(self, image=None, image_path=None, save=None, use_patches=False): def run(self, image=None, image_path=None, save=None, use_patches=False):
if (image is not None and image_path is not None) or \ if (image is not None and image_path is not None) or (image is None and image_path is None):
(image is None and image_path is None):
raise ValueError("Must pass either a opencv2 image or an image_path") raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None: if image_path is not None:
image = cv2.imread(image_path) image = cv2.imread(image_path)
img_last = 0 img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) self.log.info(f"Predicting with model {model_file} [{n + 1}/{len(self.model_files)}]")
res = self.predict(model, image, use_patches) res = self.predict(model, image, use_patches)
@ -272,5 +266,7 @@ class SbbBinarizer:
img_last[:, :][img_last[:, :] > 0] = 255 img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255 img_last = (img_last[:, :] == 0) * 255
if save: if save:
# Create the output directory (and, if necessary, its parents) if it doesn't exist already
Path(save).parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(save, img_last) cv2.imwrite(save, img_last)
return img_last return img_last

Loading…
Cancel
Save