do_prediction* for "col_classifier": pass array as float16 instead of float64

This commit is contained in:
Robert Sachunsky 2026-03-15 03:20:39 +01:00
parent f54deff452
commit 67e9f84b54

View file

@ -327,7 +327,7 @@ class Eynollah:
page_coord[2]: page_coord[3]]
img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16)
label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0]
num_col = np.argmax(label_p_pred) + 1
@ -386,7 +386,7 @@ class Eynollah:
image['coord_page'][2]: image['coord_page'][3]]
img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2)
img_in = img_in / 255.0
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16)
label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0]
num_col = np.argmax(label_p_pred) + 1
@ -895,7 +895,7 @@ class Eynollah:
return cropped_page, page_coord
def extract_text_regions_new(self, img, patches, cols):
self.logger.debug("enter extract_text_regions")
self.logger.debug("enter extract_text_regions_new")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
@ -927,10 +927,9 @@ class Eynollah:
else:
prediction_regions = self.do_prediction(
False, img, self.model_zoo.get("region_fl_np"),
n_batch_inference=2,
thresholding_for_heading=False)
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
self.logger.debug("exit extract_text_regions")
self.logger.debug("exit extract_text_regions_new")
return prediction_regions
def extract_text_regions(self, img, patches, cols):