From 8f82e81551953f5cdd89268a6b4472dcdb20f7bc Mon Sep 17 00:00:00 2001 From: Konstantin Baierer Date: Fri, 5 Feb 2021 17:35:29 +0100 Subject: [PATCH] remove unnecessary patches assignment, simplify if-else --- sbb_newspapers_org_image/eynollah.py | 224 ++++++++++++--------------- 1 file changed, 102 insertions(+), 122 deletions(-) diff --git a/sbb_newspapers_org_image/eynollah.py b/sbb_newspapers_org_image/eynollah.py index 5bf67e8..4b2e5bc 100644 --- a/sbb_newspapers_org_image/eynollah.py +++ b/sbb_newspapers_org_image/eynollah.py @@ -499,7 +499,26 @@ class eynollah: img_width_model = model.layers[len(model.layers) - 1].output_shape[2] n_classes = model.layers[len(model.layers) - 1].output_shape[3] - if patches: + + if not patches: + img_h_page = img.shape[0] + img_w_page = img.shape[1] + img = img / float(255.0) + img = resize_image(img, img_height_model, img_width_model) + + label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2])) + + seg = np.argmax(label_p_pred, axis=3)[0] + seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + prediction_true = resize_image(seg_color, img_h_page, img_w_page) + prediction_true = prediction_true.astype(np.uint8) + + del img + del seg_color + del label_p_pred + del seg + + else: if img.shape[0] < img_height_model: img = resize_image(img, img_height_model, img.shape[1]) @@ -599,39 +618,18 @@ class eynollah: del seg_color del seg del img_patch - - if not patches: - img_h_page = img.shape[0] - img_w_page = img.shape[1] - img = img / float(255.0) - img = resize_image(img, img_height_model, img_width_model) - - label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2])) - - seg = np.argmax(label_p_pred, axis=3)[0] - seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) - prediction_true = resize_image(seg_color, img_h_page, img_w_page) - prediction_true = prediction_true.astype(np.uint8) - - del img - del seg_color - del label_p_pred - del seg - del model gc.collect() - return prediction_true def early_page_for_num_of_column_classification(self): self.logger.debug("enter early_page_for_num_of_column_classification") img = cv2.imread(self.image_filename) img = img.astype(np.uint8) - patches = False model_page, session_page = self.start_new_session_and_model(self.model_page_dir) for ii in range(1): img = cv2.GaussianBlur(img, (5, 5), 0) - img_page_prediction = self.do_prediction(patches, img, model_page) + img_page_prediction = self.do_prediction(False, img, model_page) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -664,12 +662,11 @@ class eynollah: def extract_page(self): self.logger.debug("enter extract_page") - patches = False model_page, session_page = self.start_new_session_and_model(self.model_page_dir) for ii in range(1): img = cv2.GaussianBlur(self.image, (5, 5), 0) - img_page_prediction = self.do_prediction(patches, img, model_page) + img_page_prediction = self.do_prediction(False, img, model_page) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) @@ -715,92 +712,88 @@ class eynollah: img_height_h = img.shape[0] img_width_h = img.shape[1] - if patches: - model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully) - if not patches: - model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully_np) - - if patches and cols == 1: - img2 = otsu_copy_binary(img) - img2 = img2.astype(np.uint8) - img2 = resize_image(img2, int(img_height_h * 0.7), int(img_width_h * 0.7)) - marginal_of_patch_percent = 0.1 - prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) - prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) - - if patches and cols == 2: - img2 = otsu_copy_binary(img) - img2 = img2.astype(np.uint8) - img2 = resize_image(img2, int(img_height_h * 0.4), int(img_width_h * 0.4)) - marginal_of_patch_percent = 0.1 - prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) - prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) - - elif patches and cols > 2: - img2 = otsu_copy_binary(img) - img2 = img2.astype(np.uint8) - img2 = resize_image(img2, int(img_height_h * 0.3), int(img_width_h * 0.3)) - marginal_of_patch_percent = 0.1 - prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) - prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) - - if patches and cols == 2: - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - if img_width_h >= 2000: - img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) - img = img.astype(np.uint8) + model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully if patches else self.model_region_dir_fully_np) - if patches and cols == 1: + if not patches: img = otsu_copy_binary(img) img = img.astype(np.uint8) - img = resize_image(img, int(img_height_h * 0.5), int(img_width_h * 0.5)) - img = img.astype(np.uint8) - - if patches and cols == 3: - if (self.scale_x == 1 and img_width_h > 3000) or (self.scale_x != 1 and img_width_h > 2800): - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - img = resize_image(img, int(img_height_h * 2800 / float(img_width_h)), 2800) - else: - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - - if patches and cols == 4: - #print(self.scale_x,img_width_h,'scale') - if (self.scale_x == 1 and img_width_h > 4000) or (self.scale_x != 1 and img_width_h > 3700): - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 3700 / float(img_width_h)), 3700) - else: - img = otsu_copy_binary(img)#self.otsu_copy(img) - img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) - - if patches and cols==5: - if self.scale_x == 1 and img_width_h > 5000: + prediction_regions2 = None + else: + if cols == 1: + img2 = otsu_copy_binary(img) + img2 = img2.astype(np.uint8) + img2 = resize_image(img2, int(img_height_h * 0.7), int(img_width_h * 0.7)) + marginal_of_patch_percent = 0.1 + prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) + prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) + + if cols == 2: + img2 = otsu_copy_binary(img) + img2 = img2.astype(np.uint8) + img2 = resize_image(img2, int(img_height_h * 0.4), int(img_width_h * 0.4)) + marginal_of_patch_percent = 0.1 + prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) + prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) + + elif cols > 2: + img2 = otsu_copy_binary(img) + img2 = img2.astype(np.uint8) + img2 = resize_image(img2, int(img_height_h * 0.3), int(img_width_h * 0.3)) + marginal_of_patch_percent = 0.1 + prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) + prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) + + if cols == 2: img = otsu_copy_binary(img) img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 0.7), int(img_width_h * 0.7)) - else: - img = otsu_copy_binary(img) + if img_width_h >= 2000: + img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9) ) - if patches and cols>=6: - if img_width_h > 5600: + if cols == 1: img = otsu_copy_binary(img) img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 5600 / float(img_width_h)), 5600) - else: - img = otsu_copy_binary(img) + img = resize_image(img, int(img_height_h * 0.5), int(img_width_h * 0.5)) img = img.astype(np.uint8) - img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) - if not patches: - img = otsu_copy_binary(img) - img = img.astype(np.uint8) - prediction_regions2 = None + if cols == 3: + if (self.scale_x == 1 and img_width_h > 3000) or (self.scale_x != 1 and img_width_h > 2800): + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img = resize_image(img, int(img_height_h * 2800 / float(img_width_h)), 2800) + else: + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + + if cols == 4: + if (self.scale_x == 1 and img_width_h > 4000) or (self.scale_x != 1 and img_width_h > 3700): + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 3700 / float(img_width_h)), 3700) + else: + img = otsu_copy_binary(img)#self.otsu_copy(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) + + if cols == 5: + if self.scale_x == 1 and img_width_h > 5000: + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 0.7), int(img_width_h * 0.7)) + else: + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9) ) + + if cols >= 6: + if img_width_h > 5600: + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 5600 / float(img_width_h)), 5600) + else: + img = otsu_copy_binary(img) + img = img.astype(np.uint8) + img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) marginal_of_patch_percent = 0.1 prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent) @@ -1105,10 +1098,7 @@ class eynollah: def textline_contours(self, img, patches, scaler_h, scaler_w): self.logger.debug('enter textline_contours') - if patches: - model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir) - if not patches: - model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir_np) + model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir if patches else self.model_textline_dir_np) img = img.astype(np.uint8) img_org = np.copy(img) img_h = img_org.shape[0] @@ -1116,17 +1106,12 @@ class eynollah: img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) prediction_textline = self.do_prediction(patches, img, model_textline) prediction_textline = resize_image(prediction_textline, img_h, img_w) - patches = False - prediction_textline_longshot = self.do_prediction(patches, img, model_textline) + prediction_textline_longshot = self.do_prediction(False, img, model_textline) prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w) - - # prediction_textline_streched=self.do_prediction(patches,img,model_textline) - # prediction_textline_streched= resize_image(prediction_textline_streched, img_h, img_w) ##plt.imshow(prediction_textline_streched[:,:,0]) ##plt.show() session_textline.close() - del model_textline del session_textline del img @@ -1697,7 +1682,6 @@ class eynollah: model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens) gaussian_filter=False - patches=True binary=False ratio_y=1.3 ratio_x=1 @@ -1714,7 +1698,7 @@ class eynollah: img= cv2.GaussianBlur(img,(5,5),0) img = img.astype(np.uint16) - prediction_regions_org_y = self.do_prediction(patches,img,model_region) + prediction_regions_org_y = self.do_prediction(True, img, model_region) prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h ) #plt.imshow(prediction_regions_org_y[:,:,0]) @@ -1740,7 +1724,7 @@ class eynollah: img = cv2.GaussianBlur(img, (5,5 ), 0) img = img.astype(np.uint16) - prediction_regions_org = self.do_prediction(patches,img,model_region) + prediction_regions_org = self.do_prediction(True, img, model_region) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) ##plt.imshow(prediction_regions_org[:,:,0]) @@ -1757,7 +1741,6 @@ class eynollah: model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p2) gaussian_filter=False - patches=True binary=False ratio_x=1 ratio_y=1 @@ -1776,7 +1759,7 @@ class eynollah: img = img.astype(np.uint16) marginal_patch=0.2 - prediction_regions_org2=self.do_prediction(patches,img,model_region,marginal_patch) + prediction_regions_org2=self.do_prediction(True, img, model_region, marginal_patch) prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h ) @@ -2224,16 +2207,15 @@ class eynollah: num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines = self.run_graphics_and_columns(text_regions_p_1, num_column_is_classified) self.logger.info("Graphics detection took %ss ", str(time.time() - t1)) - #print(num_col, "num_colnum_col") if not num_col: self.logger.info("No columns detected, outputting an empty PAGE-XML") self.write_into_page_xml([], page_coord, self.dir_out, [], [], [], [], [], [], [], [], self.curved_line, [], []) self.logger.info("Job done in %ss", str(time.time() - t1)) return - patches = True + scaler_h_textline = 1 # 1.2#1.2 scaler_w_textline = 1 # 0.9#1 - textline_mask_tot_ea, textline_mask_tot_long_shot = self.textline_contours(image_page, patches, scaler_h_textline, scaler_w_textline) + textline_mask_tot_ea, textline_mask_tot_long_shot = self.textline_contours(image_page, True, scaler_h_textline, scaler_w_textline) K.clear_session() gc.collect() @@ -2354,11 +2336,10 @@ class eynollah: K.clear_session() # gc.collect() - patches = True image_page = image_page.astype(np.uint8) # print(type(image_page)) - regions_fully, regions_fully_only_drop = self.extract_text_regions(image_page, patches, cols=num_col_classifier) + regions_fully, regions_fully_only_drop = self.extract_text_regions(image_page, True, cols=num_col_classifier) text_regions_p[:,:][regions_fully[:,:,0]==6]=6 regions_fully_only_drop = put_drop_out_from_only_drop_model(regions_fully_only_drop, text_regions_p) @@ -2376,8 +2357,7 @@ class eynollah: K.clear_session() gc.collect() - patches = False - regions_fully_np, _ = self.extract_text_regions(image_page, patches, cols=num_col_classifier) + regions_fully_np, _ = self.extract_text_regions(image_page, False, cols=num_col_classifier) # plt.imshow(regions_fully_np[:,:,0]) # plt.show()