use (generalized) do_prediction() instead of predict_enhancement()

This commit is contained in:
Robert Sachunsky 2026-03-04 23:49:11 +01:00
parent 341480e9a0
commit 41dccb216c
2 changed files with 17 additions and 102 deletions

View file

@ -187,7 +187,6 @@ def layout_cli(
assert enable_plotting or not save_all, "Plotting with -sa also requires -ep"
assert enable_plotting or not save_page, "Plotting with -sp also requires -ep"
assert enable_plotting or not save_images, "Plotting with -si also requires -ep"
assert enable_plotting or not allow_enhancement, "Plotting with -ae also requires -ep"
assert not enable_plotting or save_layout or save_deskewed or save_all or save_page or save_images or allow_enhancement, \
"Plotting with -ep also requires -sl, -sd, -sa, -sp, -si or -ae"
assert bool(image) != bool(dir_in), "Either -i (single input) or -di (directory) must be provided, but not both."

View file

@ -192,6 +192,7 @@ class Eynollah:
loadable = [
"col_classifier",
"binarization",
#"enhancement",
"page",
"region"
]
@ -256,103 +257,6 @@ class Eynollah:
key += '_uint8'
return self._imgs[key].copy()
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
img_height_model = self.model_zoo.get("enhancement").layers[-1].output_shape[1]
img_width_model = self.model_zoo.get("enhancement").layers[-1].output_shape[2]
if img.shape[0] < img_height_model:
img = cv2.resize(img, (img.shape[1], img_width_model), interpolation=cv2.INTER_NEAREST)
if img.shape[1] < img_width_model:
img = cv2.resize(img, (img_height_model, img.shape[0]), interpolation=cv2.INTER_NEAREST)
margin = int(0 * img_width_model)
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img = img / 255.
img_h = img.shape[0]
img_w = img.shape[1]
prediction_true = np.zeros((img_h, img_w, 3))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)
for i in range(nxf):
for j in range(nyf):
if i == 0:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
else:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
if j == 0:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
else:
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
img_patch = img[np.newaxis, index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model_zoo.get("enhancement").predict(img_patch, verbose=0)
seg = label_p_pred[0, :, :, :] * 255
if i == 0 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[0:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - 0] = \
seg[margin:,
margin:]
elif i == 0 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + 0:index_x_u - margin] = \
seg[margin:,
0:-margin or None]
elif i == nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[0:-margin or None,
margin:]
elif i == 0 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + 0:index_x_u - margin] = \
seg[margin:-margin or None,
0:-margin or None]
elif i == nxf - 1 and j != 0 and j != nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - 0] = \
seg[margin:-margin or None,
margin:]
elif i != 0 and i != nxf - 1 and j == 0:
prediction_true[index_y_d + 0:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[0:-margin or None,
margin:-margin or None]
elif i != 0 and i != nxf - 1 and j == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - 0,
index_x_d + margin:index_x_u - margin] = \
seg[margin:,
margin:-margin or None]
else:
prediction_true[index_y_d + margin:index_y_u - margin,
index_x_d + margin:index_x_u - margin] = \
seg[margin:-margin or None,
margin:-margin or None]
prediction_true = prediction_true.astype(int)
return prediction_true
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
self.logger.debug("enter calculate_width_height_by_columns")
if num_col == 1 and width_early < 1100:
@ -462,7 +366,9 @@ class Eynollah:
img_new, _ = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred)
if img_new.shape[1] > img.shape[1]:
img_new = self.predict_enhancement(img_new)
img_new = self.do_prediction(True, img_new, self.model_zoo.get("enhancement"),
marginal_of_patch_percent=0,
is_enhancement=True)
is_image_enhanced = True
return img, img_new, is_image_enhanced
@ -630,6 +536,7 @@ class Eynollah:
thresholding_for_artificial_class=False,
threshold_art_class=0.1,
artificial_class=2,
is_enhancement=False,
):
self.logger.debug("enter do_prediction (patches=%d)", patches)
@ -645,7 +552,10 @@ class Eynollah:
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img[np.newaxis], verbose=0)[0]
seg = np.argmax(label_p_pred, axis=2)
if is_enhancement:
seg = (label_p_pred * 255).astype(np.uint8)
else:
seg = np.argmax(label_p_pred, axis=2)
if thresholding_for_artificial_class:
seg_mask_label(
@ -671,7 +581,10 @@ class Eynollah:
height_mid = img_height_model - 2 * margin
img_h = img.shape[0]
img_w = img.shape[1]
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
if is_enhancement:
prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8)
else:
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil(img_w / float(width_mid))
@ -715,7 +628,10 @@ class Eynollah:
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch, verbose=0)
seg = np.argmax(label_p_pred, axis=3)
if is_enhancement:
seg = (label_p_pred * 255).astype(np.uint8)
else:
seg = np.argmax(label_p_pred, axis=3)
if thresholding_for_some_classes:
seg_mask_label(