From 341480e9a050a1a728907ddacf98975e7203f468 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Wed, 4 Mar 2026 23:41:45 +0100 Subject: [PATCH] do_prediction: if img was too small for model, also upscale results (i.e. resize back to match original size after prediction) --- src/eynollah/eynollah.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index be4d2c7..2b12f67 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -635,11 +635,13 @@ class Eynollah: self.logger.debug("enter do_prediction (patches=%d)", patches) img_height_model = model.layers[-1].output_shape[1] img_width_model = model.layers[-1].output_shape[2] + img_h_page = img.shape[0] + img_w_page = img.shape[1] + + img = img / 255. + img = img.astype(np.float16) if not patches: - img_h_page = img.shape[0] - img_w_page = img.shape[1] - img = img / 255.0 img = resize_image(img, img_height_model, img_width_model) label_p_pred = model.predict(img[np.newaxis], verbose=0)[0] @@ -658,17 +660,15 @@ class Eynollah: return resize_image(seg, img_h_page, img_w_page).astype(np.uint8) - if img.shape[0] < img_height_model: + if img_h_page < img_height_model: img = resize_image(img, img_height_model, img.shape[1]) - if img.shape[1] < img_width_model: + if img_w_page < img_width_model: img = resize_image(img, img.shape[0], img_width_model) self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) margin = int(marginal_of_patch_percent * img_height_model) width_mid = img_width_model - 2 * margin height_mid = img_height_model - 2 * margin - img = img / 255. - #img = img.astype(np.float16) img_h = img.shape[0] img_w = img.shape[1] prediction = np.zeros((img_h, img_w), dtype=np.uint8) @@ -808,6 +808,10 @@ class Eynollah: only=True, skeletonize=True, dilate=3) + + if img_h != img_h_page or img_w != img_w_page: + prediction = resize_image(prediction, img_h_page, img_w_page) + gc.collect() return prediction @@ -1073,12 +1077,10 @@ class Eynollah: model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np") thresholding_for_heading = True + img = otsu_copy_binary(img).astype(np.uint8) if not patches: - img = otsu_copy_binary(img).astype(np.uint8) - prediction_regions = None thresholding_for_heading = False elif cols: - img = otsu_copy_binary(img).astype(np.uint8) if cols == 1: img = resize_image(img, int(img_height_h * 1000 / float(img_width_h)), 1000).astype(np.uint8) elif cols == 2: