Mirror of https://github.com/qurator-spk/eynollah.git (synced 2026-03-24 16:12:03 +01:00)
use (generalized) do_prediction() instead of predict_enhancement()
This commit is contained in:
parent 341480e9a0
commit 41dccb216c
2 changed files with 17 additions and 102 deletions
|
|
@ -187,7 +187,6 @@ def layout_cli(
|
||||||
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
|
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
|
||||||
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
|
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
|
||||||
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
|
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
|
||||||
assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
|
|
||||||
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
|
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
|
||||||
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
|
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
|
||||||
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."
|
||||||
|
|
|
||||||
|
|
@ -192,6 +192,7 @@ class Eynollah:
|
||||||
loadable = [
|
loadable = [
|
||||||
"col_classifier",
|
"col_classifier",
|
||||||
"binarization",
|
"binarization",
|
||||||
|
#"enhancement",
|
||||||
"page",
|
"page",
|
||||||
"region"
|
"region"
|
||||||
]
|
]
|
||||||
|
|
@ -256,103 +257,6 @@ class Eynollah:
|
||||||
key += '_uint8'
|
key += '_uint8'
|
||||||
return self._imgs[key].copy()
|
return self._imgs[key].copy()
|
||||||
|
|
||||||
def predict_enhancement(self, img):
|
|
||||||
self.logger.debug("enter predict_enhancement")
|
|
||||||
|
|
||||||
img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1]
|
|
||||||
img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2]
|
|
||||||
if img.shape[0] < img_height_model:
|
|
||||||
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
|
|
||||||
if img.shape[1] < img_width_model:
|
|
||||||
img = cv2.resize(img, (img_height_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
|
|
||||||
margin = int(0 * img_width_model)
|
|
||||||
width_mid = img_width_model - 2 * margin
|
|
||||||
height_mid = img_height_model - 2 * margin
|
|
||||||
img = img / 255.
|
|
||||||
img_h = img.shape[0]
|
|
||||||
img_w = img.shape[1]
|
|
||||||
|
|
||||||
prediction_true = np.zeros((img_h, img_w, 3))
|
|
||||||
nxf = img_w / float(width_mid)
|
|
||||||
nyf = img_h / float(height_mid)
|
|
||||||
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
|
|
||||||
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
|
|
||||||
|
|
||||||
for i in range(nxf):
|
|
||||||
for j in range(nyf):
|
|
||||||
if i == 0:
|
|
||||||
index_x_d = i * width_mid
|
|
||||||
index_x_u = index_x_d + img_width_model
|
|
||||||
else:
|
|
||||||
index_x_d = i * width_mid
|
|
||||||
index_x_u = index_x_d + img_width_model
|
|
||||||
if j == 0:
|
|
||||||
index_y_d = j * height_mid
|
|
||||||
index_y_u = index_y_d + img_height_model
|
|
||||||
else:
|
|
||||||
index_y_d = j * height_mid
|
|
||||||
index_y_u = index_y_d + img_height_model
|
|
||||||
|
|
||||||
if index_x_u > img_w:
|
|
||||||
index_x_u = img_w
|
|
||||||
index_x_d = img_w - img_width_model
|
|
||||||
if index_y_u > img_h:
|
|
||||||
index_y_u = img_h
|
|
||||||
index_y_d = img_h - img_height_model
|
|
||||||
|
|
||||||
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
|
|
||||||
label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0)
|
|
||||||
seg = label_p_pred[0, :, :, :] * 255
|
|
||||||
|
|
||||||
if i == 0 and j == 0:
|
|
||||||
prediction_true[index_y_d + 0:index_y_u - margin,
|
|
||||||
index_x_d + 0:index_x_u - margin] = \
|
|
||||||
seg[0:-margin or None,
|
|
||||||
0:-margin or None]
|
|
||||||
elif i == nxf - 1 and j == nyf - 1:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - 0,
|
|
||||||
index_x_d + margin:index_x_u - 0] = \
|
|
||||||
seg[margin:,
|
|
||||||
margin:]
|
|
||||||
elif i == 0 and j == nyf - 1:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - 0,
|
|
||||||
index_x_d + 0:index_x_u - margin] = \
|
|
||||||
seg[margin:,
|
|
||||||
0:-margin or None]
|
|
||||||
elif i == nxf - 1 and j == 0:
|
|
||||||
prediction_true[index_y_d + 0:index_y_u - margin,
|
|
||||||
index_x_d + margin:index_x_u - 0] = \
|
|
||||||
seg[0:-margin or None,
|
|
||||||
margin:]
|
|
||||||
elif i == 0 and j != 0 and j != nyf - 1:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - margin,
|
|
||||||
index_x_d + 0:index_x_u - margin] = \
|
|
||||||
seg[margin:-margin or None,
|
|
||||||
0:-margin or None]
|
|
||||||
elif i == nxf - 1 and j != 0 and j != nyf - 1:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - margin,
|
|
||||||
index_x_d + margin:index_x_u - 0] = \
|
|
||||||
seg[margin:-margin or None,
|
|
||||||
margin:]
|
|
||||||
elif i != 0 and i != nxf - 1 and j == 0:
|
|
||||||
prediction_true[index_y_d + 0:index_y_u - margin,
|
|
||||||
index_x_d + margin:index_x_u - margin] = \
|
|
||||||
seg[0:-margin or None,
|
|
||||||
margin:-margin or None]
|
|
||||||
elif i != 0 and i != nxf - 1 and j == nyf - 1:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - 0,
|
|
||||||
index_x_d + margin:index_x_u - margin] = \
|
|
||||||
seg[margin:,
|
|
||||||
margin:-margin or None]
|
|
||||||
else:
|
|
||||||
prediction_true[index_y_d + margin:index_y_u - margin,
|
|
||||||
index_x_d + margin:index_x_u - margin] = \
|
|
||||||
seg[margin:-margin or None,
|
|
||||||
margin:-margin or None]
|
|
||||||
|
|
||||||
prediction_true = prediction_true.astype(int)
|
|
||||||
return prediction_true
|
|
||||||
|
|
||||||
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
|
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
|
||||||
self.logger.debug("enter calculate_width_height_by_columns")
|
self.logger.debug("enter calculate_width_height_by_columns")
|
||||||
if num_col == 1 and width_early < 1100:
|
if num_col == 1 and width_early < 1100:
|
||||||
|
|
@ -462,7 +366,9 @@ class Eynollah:
|
||||||
img_new, _ = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred)
|
img_new, _ = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred)
|
||||||
|
|
||||||
if img_new.shape[1] > img.shape[1]:
|
if img_new.shape[1] > img.shape[1]:
|
||||||
img_new = self.predict_enhancement(img_new)
|
img_new = self.do_prediction(True, img_new, self.model_zoo.get("enhancement"),
|
||||||
|
marginal_of_patch_percent=0,
|
||||||
|
is_enhancement=True)
|
||||||
is_image_enhanced = True
|
is_image_enhanced = True
|
||||||
|
|
||||||
return img, img_new, is_image_enhanced
|
return img, img_new, is_image_enhanced
|
||||||
|
|
@ -630,6 +536,7 @@ class Eynollah:
|
||||||
thresholding_for_artificial_class=False,
|
thresholding_for_artificial_class=False,
|
||||||
threshold_art_class=0.1,
|
threshold_art_class=0.1,
|
||||||
artificial_class=2,
|
artificial_class=2,
|
||||||
|
is_enhancement=False,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.logger.debug("enter do_prediction (patches=%d)", patches)
|
self.logger.debug("enter do_prediction (patches=%d)", patches)
|
||||||
|
|
@ -645,6 +552,9 @@ class Eynollah:
|
||||||
img = resize_image(img, img_height_model, img_width_model)
|
img = resize_image(img, img_height_model, img_width_model)
|
||||||
|
|
||||||
label_p_pred = model.predict(img[np.newaxis], verbose=0)[0]
|
label_p_pred = model.predict(img[np.newaxis], verbose=0)[0]
|
||||||
|
if is_enhancement:
|
||||||
|
seg = (label_p_pred * 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
seg = np.argmax(label_p_pred, axis=2)
|
seg = np.argmax(label_p_pred, axis=2)
|
||||||
|
|
||||||
if thresholding_for_artificial_class:
|
if thresholding_for_artificial_class:
|
||||||
|
|
@ -671,6 +581,9 @@ class Eynollah:
|
||||||
height_mid = img_height_model - 2 * margin
|
height_mid = img_height_model - 2 * margin
|
||||||
img_h = img.shape[0]
|
img_h = img.shape[0]
|
||||||
img_w = img.shape[1]
|
img_w = img.shape[1]
|
||||||
|
if is_enhancement:
|
||||||
|
prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8)
|
||||||
|
else:
|
||||||
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
|
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
|
||||||
if thresholding_for_artificial_class:
|
if thresholding_for_artificial_class:
|
||||||
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
|
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
|
||||||
|
|
@ -715,6 +628,9 @@ class Eynollah:
|
||||||
i == nxf - 1 and j == nyf - 1):
|
i == nxf - 1 and j == nyf - 1):
|
||||||
self.logger.debug("predicting patches on %s", str(img_patch.shape))
|
self.logger.debug("predicting patches on %s", str(img_patch.shape))
|
||||||
label_p_pred = model.predict(img_patch, verbose=0)
|
label_p_pred = model.predict(img_patch, verbose=0)
|
||||||
|
if is_enhancement:
|
||||||
|
seg = (label_p_pred * 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
seg = np.argmax(label_p_pred, axis=3)
|
seg = np.argmax(label_p_pred, axis=3)
|
||||||
|
|
||||||
if thresholding_for_some_classes:
|
if thresholding_for_some_classes:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue