From 41dccb216c8fc304cecb2cbd2ba0f790b5a0faae Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 4 Mar 2026 23:49:11 +0100 Subject: [PATCH] use (generalized) `do_prediction()` instead of `predict_enhancement()` --- src/eynollah/cli/cli_layout.py | 1 - src/eynollah/eynollah.py | 118 +++++---------------------------- 2 files changed, 17 insertions(+), 102 deletions(-) diff --git a/src/eynollah/cli/cli_layout.py b/src/eynollah/cli/cli_layout.py index df66993..9d9f325 100644 --- a/src/eynollah/cli/cli_layout.py +++ b/src/eynollah/cli/cli_layout.py @@ -187,7 +187,6 @@ def layout_cli( assert enable_plotting or not save_all, "Plotting with -sa also requires -ep" assert enable_plotting or not save_page, "Plotting with -sp also requires -ep" assert enable_plotting or not save_images, "Plotting with -si also requires -ep" - assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep" assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \ "Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae" assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both." diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 2b12f67..dcdc642 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -192,6 +192,7 @@ class Eynollah: loadable = [ "col_classifier", "binarization", + #"enhancement", "page", "region" ] @@ -256,103 +257,6 @@ class Eynollah: key += '_uint8' return self._imgs[key].copy() - def predict_enhancement(self, img): - self.logger.debug("enter predict_enhancement") - - img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1] - img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2] - if img.shape[0] < img_height_model: - img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST) - if img.shape[1] < img_width_model: - img = cv2.resize(img, (img_height_model, img.shape[0]), interpolation=cv2.INTER_NEAREST) - margin = int(0 * img_width_model) - width_mid = img_width_model - 2 * margin - height_mid = img_height_model - 2 * margin - img = img / 255. - img_h = img.shape[0] - img_w = img.shape[1] - - prediction_true = np.zeros((img_h, img_w, 3)) - nxf = img_w / float(width_mid) - nyf = img_h / float(height_mid) - nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf) - nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf) - - for i in range(nxf): - for j in range(nyf): - if i == 0: - index_x_d = i * width_mid - index_x_u = index_x_d + img_width_model - else: - index_x_d = i * width_mid - index_x_u = index_x_d + img_width_model - if j == 0: - index_y_d = j * height_mid - index_y_u = index_y_d + img_height_model - else: - index_y_d = j * height_mid - index_y_u = index_y_d + img_height_model - - if index_x_u > img_w: - index_x_u = img_w - index_x_d = img_w - img_width_model - if index_y_u > img_h: - index_y_u = img_h - index_y_d = img_h - img_height_model - - img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :] - label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0) - seg = label_p_pred[0, :, :, :] * 255 - - if i == 0 and j == 0: - prediction_true[index_y_d + 0:index_y_u - margin, - index_x_d + 0:index_x_u - margin] = \ - seg[0:-margin or None, - 0:-margin or None] - elif i == nxf - 1 and j == nyf - 1: - prediction_true[index_y_d + margin:index_y_u - 0, - index_x_d + margin:index_x_u - 0] = \ - seg[margin:, - margin:] - elif i == 0 and j == nyf - 1: - prediction_true[index_y_d + margin:index_y_u - 0, - index_x_d + 0:index_x_u - margin] = \ - seg[margin:, - 0:-margin or None] - elif i == nxf - 1 and j == 0: - prediction_true[index_y_d + 0:index_y_u - margin, - index_x_d + margin:index_x_u - 0] = \ - seg[0:-margin or None, - margin:] - elif i == 0 and j != 0 and j != nyf - 1: - prediction_true[index_y_d + margin:index_y_u - margin, - index_x_d + 0:index_x_u - margin] = \ - seg[margin:-margin or None, - 0:-margin or None] - elif i == nxf - 1 and j != 0 and j != nyf - 1: - prediction_true[index_y_d + margin:index_y_u - margin, - index_x_d + margin:index_x_u - 0] = \ - seg[margin:-margin or None, - margin:] - elif i != 0 and i != nxf - 1 and j == 0: - prediction_true[index_y_d + 0:index_y_u - margin, - index_x_d + margin:index_x_u - margin] = \ - seg[0:-margin or None, - margin:-margin or None] - elif i != 0 and i != nxf - 1 and j == nyf - 1: - prediction_true[index_y_d + margin:index_y_u - 0, - index_x_d + margin:index_x_u - margin] = \ - seg[margin:, - margin:-margin or None] - else: - prediction_true[index_y_d + margin:index_y_u - margin, - index_x_d + margin:index_x_u - margin] = \ - seg[margin:-margin or None, - margin:-margin or None] - - prediction_true = prediction_true.astype(int) - return prediction_true - def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred): self.logger.debug("enter calculate_width_height_by_columns") if num_col == 1 and width_early < 1100: @@ -462,7 +366,9 @@ class Eynollah: img_new, _ = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred) if img_new.shape[1] > img.shape[1]: - img_new = self.predict_enhancement(img_new) + img_new = self.do_prediction(True, img_new, self.model_zoo.get("enhancement"), + marginal_of_patch_percent=0, + is_enhancement=True) is_image_enhanced = True return img, img_new, is_image_enhanced @@ -630,6 +536,7 @@ class Eynollah: thresholding_for_artificial_class=False, threshold_art_class=0.1, artificial_class=2, + is_enhancement=False, ): self.logger.debug("enter do_prediction (patches=%d)", patches) @@ -645,7 +552,10 @@ class Eynollah: img = resize_image(img, img_height_model, img_width_model) label_p_pred = model.predict(img[np.newaxis], verbose=0)[0] - seg = np.argmax(label_p_pred, axis=2) + if is_enhancement: + seg = (label_p_pred * 255).astype(np.uint8) + else: + seg = np.argmax(label_p_pred, axis=2) if thresholding_for_artificial_class: seg_mask_label( @@ -671,7 +581,10 @@ class Eynollah: height_mid = img_height_model - 2 * margin img_h = img.shape[0] img_w = img.shape[1] - prediction = np.zeros((img_h, img_w), dtype=np.uint8) + if is_enhancement: + prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8) + else: + prediction = np.zeros((img_h, img_w), dtype=np.uint8) if thresholding_for_artificial_class: mask_artificial_class = np.zeros((img_h, img_w), dtype=bool) nxf = math.ceil(img_w / float(width_mid)) @@ -715,7 +628,10 @@ class Eynollah: i == nxf - 1 and j == nyf - 1): self.logger.debug("predicting patches on %s", str(img_patch.shape)) label_p_pred = model.predict(img_patch, verbose=0) - seg = np.argmax(label_p_pred, axis=3) + if is_enhancement: + seg = (label_p_pred * 255).astype(np.uint8) + else: + seg = np.argmax(label_p_pred, axis=3) if thresholding_for_some_classes: seg_mask_label(