do_prediction: if img was too small for model, also resize results back

(i.e. resize back to match original size after prediction)
This commit is contained in:
Robert Sachunsky 2026-03-04 23:41:45 +01:00
parent 8ebbe65c17
commit 341480e9a0

View file

@ -635,11 +635,13 @@ class Eynollah:
self.logger.debug("enter do_prediction (patches=%d)", patches)
img_height_model = model.layers[-1].output_shape[1]
img_width_model = model.layers[-1].output_shape[2]
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / 255.
img = img.astype(np.float16)
if not patches:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / 255.0
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img[np.newaxis], verbose=0)[0]
@ -658,17 +660,15 @@ class Eynollah:
return resize_image(seg, img_h_page, img_w_page).astype(np.uint8)
if img.shape[0] < img_height_model:
if img_h_page < img_height_model:
img = resize_image(img, img_height_model, img.shape[1])
if img.shape[1] < img_width_model:
if img_w_page < img_width_model:
img = resize_image(img, img.shape[0], img_width_model)
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
#img = img.astype(np.float16)
img_h = img.shape[0]
img_w = img.shape[1]
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
@ -808,6 +808,10 @@ class Eynollah:
only=True,
skeletonize=True,
dilate=3)
if img_h != img_h_page or img_w != img_w_page:
prediction = resize_image(prediction, img_h_page, img_w_page)
gc.collect()
return prediction
@ -1073,12 +1077,10 @@ class Eynollah:
model_region = self.model_zoo.get("region_fl") if patches else self.model_zoo.get("region_fl_np")
thresholding_for_heading = True
img = otsu_copy_binary(img).astype(np.uint8)
if not patches:
img = otsu_copy_binary(img).astype(np.uint8)
prediction_regions = None
thresholding_for_heading = False
elif cols:
img = otsu_copy_binary(img).astype(np.uint8)
if cols == 1:
img = resize_image(img, int(img_height_h * 1000 / float(img_width_h)), 1000).astype(np.uint8)
elif cols == 2: