mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-04-14 19:31:57 +02:00
do_prediction* for "col_classifier": pass array as float16 instead of float64
This commit is contained in:
parent
f54deff452
commit
67e9f84b54
1 changed files with 4 additions and 5 deletions
|
|
@ -327,7 +327,7 @@ class Eynollah:
|
|||
page_coord[2]: page_coord[3]]
|
||||
img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2)
|
||||
img_in = img_in / 255.0
|
||||
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
|
||||
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16)
|
||||
|
||||
label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0]
|
||||
num_col = np.argmax(label_p_pred) + 1
|
||||
|
|
@ -386,7 +386,7 @@ class Eynollah:
|
|||
image['coord_page'][2]: image['coord_page'][3]]
|
||||
img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2)
|
||||
img_in = img_in / 255.0
|
||||
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST)
|
||||
img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16)
|
||||
|
||||
label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0]
|
||||
num_col = np.argmax(label_p_pred) + 1
|
||||
|
|
@ -895,7 +895,7 @@ class Eynollah:
|
|||
return cropped_page, page_coord
|
||||
|
||||
def extract_text_regions_new(self, img, patches, cols):
|
||||
self.logger.debug("enter extract_text_regions")
|
||||
self.logger.debug("enter extract_text_regions_new")
|
||||
img_height_h = img.shape[0]
|
||||
img_width_h = img.shape[1]
|
||||
|
||||
|
|
@ -927,10 +927,9 @@ class Eynollah:
|
|||
else:
|
||||
prediction_regions = self.do_prediction(
|
||||
False, img, self.model_zoo.get("region_fl_np"),
|
||||
n_batch_inference=2,
|
||||
thresholding_for_heading=False)
|
||||
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
|
||||
self.logger.debug("exit extract_text_regions")
|
||||
self.logger.debug("exit extract_text_regions_new")
|
||||
return prediction_regions
|
||||
|
||||
def extract_text_regions(self, img, patches, cols):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue