diff --git a/qurator/eynollah/cli.py b/qurator/eynollah/cli.py index b0f55cd..357582c 100644 --- a/qurator/eynollah/cli.py +++ b/qurator/eynollah/cli.py @@ -191,6 +191,16 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i is_flag=True, help="if this parameter set to true, this tool will try to do ocr", ) +@click.option( + "--num_col_upper", + "-ncu", + help="lower limit of columns in document image", +) +@click.option( + "--num_col_lower", + "-ncl", + help="upper limit of columns in document image", +) @click.option( "--log_level", "-l", @@ -198,7 +208,7 @@ def machine_based_reading_order(dir_xml, dir_out_modal_image, dir_out_classes, i help="Override log level globally to this", ) -def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, save_all, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, ignore_page_extraction, log_level): +def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, save_all, save_page, enable_plotting, allow_enhancement, curved_line, textline_light, full_layout, tables, right2left, input_binary, allow_scaling, headers_off, light_version, reading_order_machine_based, do_ocr, num_col_upper, num_col_lower, ignore_page_extraction, log_level): if log_level: setOverrideLogLevel(log_level) initLogging() @@ -235,6 +245,8 @@ def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, s ignore_page_extraction=ignore_page_extraction, reading_order_machine_based=reading_order_machine_based, do_ocr=do_ocr, + num_col_upper=num_col_upper, + num_col_lower=num_col_lower, ) if dir_in: eynollah.run() diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index 569aec5..f76dce8 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -178,6 +178,8 @@ class Eynollah: ignore_page_extraction=False, reading_order_machine_based=False, do_ocr=False, + num_col_upper=None, + num_col_lower=None, override_dpi=None, logger=None, pcgts=None, @@ -212,6 +214,14 @@ class Eynollah: self.headers_off = headers_off self.ignore_page_extraction = ignore_page_extraction self.ocr = do_ocr + if num_col_upper: + self.num_col_upper = int(num_col_upper) + else: + self.num_col_upper = num_col_upper + if num_col_lower: + self.num_col_lower = int(num_col_lower) + else: + self.num_col_lower = num_col_lower self.pcgts = pcgts if not dir_in: self.plotter = None if not enable_plotting else EynollahPlotter( @@ -597,36 +607,80 @@ class Eynollah: else: img = self.imread() img_bin = None - + + width_early = img.shape[1] t1 = time.time() _, page_coord = self.early_page_for_num_of_column_classification(img_bin) if not self.dir_in: model_num_classifier, session_col_classifier = self.start_new_session_and_model(self.model_dir_of_col_classifier) - if self.input_binary: - img_in = np.copy(img) - width_early = img_in.shape[1] - img_in = img_in / 255.0 - img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST) - img_in = img_in.reshape(1, 448, 448, 3) - else: - img_1ch = self.imread(grayscale=True) - width_early = img_1ch.shape[1] - img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + if self.num_col_upper and not self.num_col_lower: + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + elif self.num_col_lower and not self.num_col_upper: + num_col = self.num_col_lower + label_p_pred = [np.ones(6)] + + elif (not self.num_col_upper and not self.num_col_lower): + if self.input_binary: + img_in = np.copy(img) + img_in = img_in / 255.0 + img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = img_in.reshape(1, 448, 448, 3) + else: + img_1ch = self.imread(grayscale=True) + width_early = img_1ch.shape[1] + img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] - img_1ch = img_1ch / 255.0 - img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST) - img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) - img_in[0, :, :, 0] = img_1ch[:, :] - img_in[0, :, :, 1] = img_1ch[:, :] - img_in[0, :, :, 2] = img_1ch[:, :] + img_1ch = img_1ch / 255.0 + img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) + img_in[0, :, :, 0] = img_1ch[:, :] + img_in[0, :, :, 1] = img_1ch[:, :] + img_in[0, :, :, 2] = img_1ch[:, :] - if self.dir_in: - label_p_pred = self.model_classifier.predict(img_in, verbose=0) + if self.dir_in: + label_p_pred = self.model_classifier.predict(img_in, verbose=0) + else: + label_p_pred = model_num_classifier.predict(img_in, verbose=0) + num_col = np.argmax(label_p_pred[0]) + 1 + elif (self.num_col_upper and self.num_col_lower) and (self.num_col_upper!=self.num_col_lower): + if self.input_binary: + img_in = np.copy(img) + img_in = img_in / 255.0 + img_in = cv2.resize(img_in, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = img_in.reshape(1, 448, 448, 3) + else: + img_1ch = self.imread(grayscale=True) + width_early = img_1ch.shape[1] + img_1ch = img_1ch[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + + img_1ch = img_1ch / 255.0 + img_1ch = cv2.resize(img_1ch, (448, 448), interpolation=cv2.INTER_NEAREST) + img_in = np.zeros((1, img_1ch.shape[0], img_1ch.shape[1], 3)) + img_in[0, :, :, 0] = img_1ch[:, :] + img_in[0, :, :, 1] = img_1ch[:, :] + img_in[0, :, :, 2] = img_1ch[:, :] + + + if self.dir_in: + label_p_pred = self.model_classifier.predict(img_in, verbose=0) + else: + label_p_pred = model_num_classifier.predict(img_in, verbose=0) + num_col = np.argmax(label_p_pred[0]) + 1 + + if num_col > self.num_col_upper: + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + if num_col < self.num_col_lower: + num_col = self.num_col_lower + label_p_pred = [np.ones(6)] + else: - label_p_pred = model_num_classifier.predict(img_in, verbose=0) - num_col = np.argmax(label_p_pred[0]) + 1 + num_col = self.num_col_upper + label_p_pred = [np.ones(6)] + self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))