do_prediction*(): ensure always returns dtype=uint8

This commit is contained in:
Robert Sachunsky 2026-05-08 17:36:31 +02:00
parent 68a26a5c3f
commit 58afdf5e87

View file

@@ -414,7 +414,7 @@ class Eynollah:
is_image_enhanced = False
self.logger.debug("exit resize_and_enhance_image_with_column_classifier")
image['img_res'] = img_res
image['img_res'] = img_res.astype(np.uint8)
image['scale_y'] = 1.0 * img_res.shape[0] / img.shape[0]
image['scale_x'] = 1.0 * img_res.shape[1] / img.shape[1]
return is_image_enhanced, num_col, is_image_resized
@@ -447,7 +447,7 @@ class Eynollah:
if is_enhancement:
seg = (label_p_pred * 255).astype(np.uint8)
else:
seg = np.argmax(label_p_pred, axis=2)
seg = np.argmax(label_p_pred, axis=2).astype(np.uint8)
if thresholding_for_artificial_class:
seg_mask_label(
@@ -460,7 +460,7 @@ class Eynollah:
seg, label_p_pred[:, :, heading_class] >= 0.2,
label=heading_class)
return resize_image(seg, img_h_page, img_w_page).astype(np.uint8)
return resize_image(seg, img_h_page, img_w_page)
if img_h_page < img_height_model:
img = resize_image(img, img_height_model, img.shape[1])
@@ -567,7 +567,7 @@ class Eynollah:
if is_enhancement:
seg = (prediction * 255).astype(np.uint8)
else:
seg = np.argmax(prediction, axis=2)
seg = np.argmax(prediction, axis=2).astype(np.uint8)
if thresholding_for_some_classes:
seg_mask_label(
seg, prediction[:, :, 4] > 0.03,
@@ -616,9 +616,9 @@ class Eynollah:
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img[np.newaxis], verbose=0)[0]
seg = np.argmax(label_p_pred, axis=2)
seg = np.argmax(label_p_pred, axis=2).astype(np.uint8)
prediction = resize_image(seg, img_h_page, img_w_page).astype(np.uint8)
prediction = resize_image(seg, img_h_page, img_w_page)
if thresholding_for_artificial_class:
mask = resize_image(label_p_pred[:, :, artificial_class],
@@ -741,7 +741,7 @@ class Eynollah:
img_patch[:] = 0
# decode
seg = np.argmax(prediction, axis=2)
seg = np.argmax(prediction, axis=2).astype(np.uint8)
conf = prediction[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class:
seg_art = prediction[:, :, artificial_class] >= threshold_art_class
@@ -1082,10 +1082,9 @@ class Eynollah:
mask_seps_only = (prediction_regions == label_seps).astype('uint8')
mask_tabs_only = prediction_tables
##if num_col_classifier == 1 or num_col_classifier == 2:
###mask_texts_only = cv2.erode(mask_texts_only, KERNEL, iterations=1)
##mask_texts_only = cv2.dilate(mask_texts_only, KERNEL, iterations=1)
mask_texts_only = cv2.dilate(mask_texts_only, kernel=np.ones((2,2), np.uint8), iterations=1)
# if num_col_classifier == 1 or num_col_classifier == 2:
# mask_texts_only = cv2.morphologyEx(mask_texts_only, cv2.MORPH_OPEN, KERNEL, iterations=1)
mask_texts_only = cv2.dilate(mask_texts_only, kernel=np.ones((2, 2), np.uint8), iterations=1)
polygons_seplines, hir_seplines = return_contours_of_image(mask_seps_only)
polygons_seplines = filter_contours_area_of_image(