remove unnecessary patches assignment, simplify if-else

pull/19/head
Konstantin Baierer 4 years ago
parent 9dca742694
commit 8f82e81551

@ -499,7 +499,26 @@ class eynollah:
img_width_model = model.layers[len(model.layers) - 1].output_shape[2] img_width_model = model.layers[len(model.layers) - 1].output_shape[2]
n_classes = model.layers[len(model.layers) - 1].output_shape[3] n_classes = model.layers[len(model.layers) - 1].output_shape[3]
if patches:
if not patches:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8)
del img
del seg_color
del label_p_pred
del seg
else:
if img.shape[0] < img_height_model: if img.shape[0] < img_height_model:
img = resize_image(img, img_height_model, img.shape[1]) img = resize_image(img, img_height_model, img.shape[1])
@ -599,39 +618,18 @@ class eynollah:
del seg_color del seg_color
del seg del seg
del img_patch del img_patch
if not patches:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = resize_image(img, img_height_model, img_width_model)
label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8)
del img
del seg_color
del label_p_pred
del seg
del model
gc.collect() gc.collect()
return prediction_true return prediction_true
def early_page_for_num_of_column_classification(self): def early_page_for_num_of_column_classification(self):
self.logger.debug("enter early_page_for_num_of_column_classification") self.logger.debug("enter early_page_for_num_of_column_classification")
img = cv2.imread(self.image_filename) img = cv2.imread(self.image_filename)
img = img.astype(np.uint8) img = img.astype(np.uint8)
patches = False
model_page, session_page = self.start_new_session_and_model(self.model_page_dir) model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
for ii in range(1): for ii in range(1):
img = cv2.GaussianBlur(img, (5, 5), 0) img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(patches, img, model_page) img_page_prediction = self.do_prediction(False, img, model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0) _, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -664,12 +662,11 @@ class eynollah:
def extract_page(self): def extract_page(self):
self.logger.debug("enter extract_page") self.logger.debug("enter extract_page")
patches = False
model_page, session_page = self.start_new_session_and_model(self.model_page_dir) model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
for ii in range(1): for ii in range(1):
img = cv2.GaussianBlur(self.image, (5, 5), 0) img = cv2.GaussianBlur(self.image, (5, 5), 0)
img_page_prediction = self.do_prediction(patches, img, model_page) img_page_prediction = self.do_prediction(False, img, model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0) _, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -715,12 +712,14 @@ class eynollah:
img_height_h = img.shape[0] img_height_h = img.shape[0]
img_width_h = img.shape[1] img_width_h = img.shape[1]
if patches: model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully if patches else self.model_region_dir_fully_np)
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully)
if not patches:
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_fully_np)
if patches and cols == 1: if not patches:
img = otsu_copy_binary(img)
img = img.astype(np.uint8)
prediction_regions2 = None
else:
if cols == 1:
img2 = otsu_copy_binary(img) img2 = otsu_copy_binary(img)
img2 = img2.astype(np.uint8) img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.7), int(img_width_h * 0.7)) img2 = resize_image(img2, int(img_height_h * 0.7), int(img_width_h * 0.7))
@ -728,7 +727,7 @@ class eynollah:
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)
if patches and cols == 2: if cols == 2:
img2 = otsu_copy_binary(img) img2 = otsu_copy_binary(img)
img2 = img2.astype(np.uint8) img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.4), int(img_width_h * 0.4)) img2 = resize_image(img2, int(img_height_h * 0.4), int(img_width_h * 0.4))
@ -736,7 +735,7 @@ class eynollah:
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)
elif patches and cols > 2: elif cols > 2:
img2 = otsu_copy_binary(img) img2 = otsu_copy_binary(img)
img2 = img2.astype(np.uint8) img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.3), int(img_width_h * 0.3)) img2 = resize_image(img2, int(img_height_h * 0.3), int(img_width_h * 0.3))
@ -744,20 +743,20 @@ class eynollah:
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent) prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h) prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)
if patches and cols == 2: if cols == 2:
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
if img_width_h >= 2000: if img_width_h >= 2000:
img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) img = resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9))
img = img.astype(np.uint8) img = img.astype(np.uint8)
if patches and cols == 1: if cols == 1:
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
img = resize_image(img, int(img_height_h * 0.5), int(img_width_h * 0.5)) img = resize_image(img, int(img_height_h * 0.5), int(img_width_h * 0.5))
img = img.astype(np.uint8) img = img.astype(np.uint8)
if patches and cols == 3: if cols == 3:
if (self.scale_x == 1 and img_width_h > 3000) or (self.scale_x != 1 and img_width_h > 2800): if (self.scale_x == 1 and img_width_h > 3000) or (self.scale_x != 1 and img_width_h > 2800):
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
@ -766,8 +765,7 @@ class eynollah:
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
if patches and cols == 4: if cols == 4:
#print(self.scale_x,img_width_h,'scale')
if (self.scale_x == 1 and img_width_h > 4000) or (self.scale_x != 1 and img_width_h > 3700): if (self.scale_x == 1 and img_width_h > 4000) or (self.scale_x != 1 and img_width_h > 3700):
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
@ -777,7 +775,7 @@ class eynollah:
img = img.astype(np.uint8) img = img.astype(np.uint8)
img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9))
if patches and cols==5: if cols == 5:
if self.scale_x == 1 and img_width_h > 5000: if self.scale_x == 1 and img_width_h > 5000:
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
@ -787,7 +785,7 @@ class eynollah:
img = img.astype(np.uint8) img = img.astype(np.uint8)
img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9) ) img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9) )
if patches and cols>=6: if cols >= 6:
if img_width_h > 5600: if img_width_h > 5600:
img = otsu_copy_binary(img) img = otsu_copy_binary(img)
img = img.astype(np.uint8) img = img.astype(np.uint8)
@ -797,11 +795,6 @@ class eynollah:
img = img.astype(np.uint8) img = img.astype(np.uint8)
img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9)) img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9))
if not patches:
img = otsu_copy_binary(img)
img = img.astype(np.uint8)
prediction_regions2 = None
marginal_of_patch_percent = 0.1 marginal_of_patch_percent = 0.1
prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent) prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent)
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h) prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
@ -1105,10 +1098,7 @@ class eynollah:
def textline_contours(self, img, patches, scaler_h, scaler_w): def textline_contours(self, img, patches, scaler_h, scaler_w):
self.logger.debug('enter textline_contours') self.logger.debug('enter textline_contours')
if patches: model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir if patches else self.model_textline_dir_np)
model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir)
if not patches:
model_textline, session_textline = self.start_new_session_and_model(self.model_textline_dir_np)
img = img.astype(np.uint8) img = img.astype(np.uint8)
img_org = np.copy(img) img_org = np.copy(img)
img_h = img_org.shape[0] img_h = img_org.shape[0]
@ -1116,17 +1106,12 @@ class eynollah:
img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
prediction_textline = self.do_prediction(patches, img, model_textline) prediction_textline = self.do_prediction(patches, img, model_textline)
prediction_textline = resize_image(prediction_textline, img_h, img_w) prediction_textline = resize_image(prediction_textline, img_h, img_w)
patches = False prediction_textline_longshot = self.do_prediction(False, img, model_textline)
prediction_textline_longshot = self.do_prediction(patches, img, model_textline)
prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w) prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
# prediction_textline_streched=self.do_prediction(patches,img,model_textline)
# prediction_textline_streched= resize_image(prediction_textline_streched, img_h, img_w)
##plt.imshow(prediction_textline_streched[:,:,0]) ##plt.imshow(prediction_textline_streched[:,:,0])
##plt.show() ##plt.show()
session_textline.close() session_textline.close()
del model_textline del model_textline
del session_textline del session_textline
del img del img
@ -1697,7 +1682,6 @@ class eynollah:
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens) model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens)
gaussian_filter=False gaussian_filter=False
patches=True
binary=False binary=False
ratio_y=1.3 ratio_y=1.3
ratio_x=1 ratio_x=1
@ -1714,7 +1698,7 @@ class eynollah:
img= cv2.GaussianBlur(img,(5,5),0) img= cv2.GaussianBlur(img,(5,5),0)
img = img.astype(np.uint16) img = img.astype(np.uint16)
prediction_regions_org_y = self.do_prediction(patches,img,model_region) prediction_regions_org_y = self.do_prediction(True, img, model_region)
prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h ) prediction_regions_org_y = resize_image(prediction_regions_org_y, img_height_h, img_width_h )
#plt.imshow(prediction_regions_org_y[:,:,0]) #plt.imshow(prediction_regions_org_y[:,:,0])
@ -1740,7 +1724,7 @@ class eynollah:
img = cv2.GaussianBlur(img, (5,5 ), 0) img = cv2.GaussianBlur(img, (5,5 ), 0)
img = img.astype(np.uint16) img = img.astype(np.uint16)
prediction_regions_org = self.do_prediction(patches,img,model_region) prediction_regions_org = self.do_prediction(True, img, model_region)
prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h ) prediction_regions_org = resize_image(prediction_regions_org, img_height_h, img_width_h )
##plt.imshow(prediction_regions_org[:,:,0]) ##plt.imshow(prediction_regions_org[:,:,0])
@ -1757,7 +1741,6 @@ class eynollah:
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p2) model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p2)
gaussian_filter=False gaussian_filter=False
patches=True
binary=False binary=False
ratio_x=1 ratio_x=1
ratio_y=1 ratio_y=1
@ -1776,7 +1759,7 @@ class eynollah:
img = img.astype(np.uint16) img = img.astype(np.uint16)
marginal_patch=0.2 marginal_patch=0.2
prediction_regions_org2=self.do_prediction(patches,img,model_region,marginal_patch) prediction_regions_org2=self.do_prediction(True, img, model_region, marginal_patch)
prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h ) prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
@ -2224,16 +2207,15 @@ class eynollah:
num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines = self.run_graphics_and_columns(text_regions_p_1, num_column_is_classified) num_col, num_col_classifier, img_only_regions, page_coord, image_page, mask_images, mask_lines = self.run_graphics_and_columns(text_regions_p_1, num_column_is_classified)
self.logger.info("Graphics detection took %ss ", str(time.time() - t1)) self.logger.info("Graphics detection took %ss ", str(time.time() - t1))
#print(num_col, "num_colnum_col")
if not num_col: if not num_col:
self.logger.info("No columns detected, outputting an empty PAGE-XML") self.logger.info("No columns detected, outputting an empty PAGE-XML")
self.write_into_page_xml([], page_coord, self.dir_out, [], [], [], [], [], [], [], [], self.curved_line, [], []) self.write_into_page_xml([], page_coord, self.dir_out, [], [], [], [], [], [], [], [], self.curved_line, [], [])
self.logger.info("Job done in %ss", str(time.time() - t1)) self.logger.info("Job done in %ss", str(time.time() - t1))
return return
patches = True
scaler_h_textline = 1 # 1.2#1.2 scaler_h_textline = 1 # 1.2#1.2
scaler_w_textline = 1 # 0.9#1 scaler_w_textline = 1 # 0.9#1
textline_mask_tot_ea, textline_mask_tot_long_shot = self.textline_contours(image_page, patches, scaler_h_textline, scaler_w_textline) textline_mask_tot_ea, textline_mask_tot_long_shot = self.textline_contours(image_page, True, scaler_h_textline, scaler_w_textline)
K.clear_session() K.clear_session()
gc.collect() gc.collect()
@ -2354,11 +2336,10 @@ class eynollah:
K.clear_session() K.clear_session()
# gc.collect() # gc.collect()
patches = True
image_page = image_page.astype(np.uint8) image_page = image_page.astype(np.uint8)
# print(type(image_page)) # print(type(image_page))
regions_fully, regions_fully_only_drop = self.extract_text_regions(image_page, patches, cols=num_col_classifier) regions_fully, regions_fully_only_drop = self.extract_text_regions(image_page, True, cols=num_col_classifier)
text_regions_p[:,:][regions_fully[:,:,0]==6]=6 text_regions_p[:,:][regions_fully[:,:,0]==6]=6
regions_fully_only_drop = put_drop_out_from_only_drop_model(regions_fully_only_drop, text_regions_p) regions_fully_only_drop = put_drop_out_from_only_drop_model(regions_fully_only_drop, text_regions_p)
@ -2376,8 +2357,7 @@ class eynollah:
K.clear_session() K.clear_session()
gc.collect() gc.collect()
patches = False regions_fully_np, _ = self.extract_text_regions(image_page, False, cols=num_col_classifier)
regions_fully_np, _ = self.extract_text_regions(image_page, patches, cols=num_col_classifier)
# plt.imshow(regions_fully_np[:,:,0]) # plt.imshow(regions_fully_np[:,:,0])
# plt.show() # plt.show()

Loading…
Cancel
Save