diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index ae182dc..c632941 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -309,6 +309,7 @@ class Eynollah: return img_new, img_is_resized + # FIXME: actually may run enhancement model, should be renamed def resize_image_with_column_classifier(self, image): self.logger.debug("enter resize_image_with_column_classifier") img = self.imread(image, binary=self.input_binary) @@ -316,19 +317,36 @@ class Eynollah: width_early = img.shape[1] page_img, page_coord = self.early_page_for_num_of_column_classification(img) - if self.input_binary: - img_in = page_img - else: - img_1ch = self.imread(image, grayscale=True, uint8=False) - img_1ch = img_1ch[page_coord[0]: page_coord[1], - page_coord[2]: page_coord[3]] - img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2) - img_in = img_in / 255.0 - img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16) + label_p_pred = np.ones(6) + conf_col = 1.0 + if self.num_col_upper and not self.num_col_lower: + num_col = self.num_col_upper + elif self.num_col_lower and not self.num_col_upper: + num_col = self.num_col_lower + elif (not self.num_col_upper and not self.num_col_lower or + self.num_col_upper != self.num_col_lower): + if self.input_binary: + img_in = page_img + else: + img_1ch = self.imread(image, grayscale=True) + img_1ch = img_1ch[page_coord[0]: page_coord[1], + page_coord[2]: page_coord[3]] + img_in = np.repeat(img_1ch[:, :, np.newaxis], 3, axis=2) + img_in = img_in / 255.0 + img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST).astype(np.float16) - label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0] - num_col = np.argmax(label_p_pred) + 1 - conf_col = np.max(label_p_pred) + label_p_pred = self.model_zoo.get("col_classifier").predict(img_in[np.newaxis], verbose=0)[0] + num_col = np.argmax(label_p_pred) + 1 + conf_col = np.max(label_p_pred) + if self.num_col_upper and self.num_col_upper < num_col: + num_col = self.num_col_upper + conf_col = 1.0 + if self.num_col_lower and self.num_col_lower > num_col: + num_col = self.num_col_lower + conf_col = 1.0 + else: + num_col = self.num_col_upper + conf_col = 1.0 self.logger.info("Found %s columns (%s)", num_col, np.around(label_p_pred, decimals=5)) if num_col in (1, 2): @@ -349,6 +367,7 @@ class Eynollah: image['scale_x'] = 1.0 * img_new.shape[1] / img.shape[1] return + # FIXME: does not actually run enhancement model, should be renamed def resize_and_enhance_image_with_column_classifier(self, image): self.logger.debug("enter resize_and_enhance_image_with_column_classifier") dpi = image['dpi']