new table detection model is integrated

pull/142/merge
vahidrezanezhad 1 month ago
parent d9f79c3404
commit b622494f34

@@ -264,9 +264,13 @@ class Eynollah:
         else:
             self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"#"/eynollah-textline_20210425"
         if self.ocr:
-            self.model_ocr_dir = dir_models + "/checkpoint-166692_printed_trocr"
-        self.model_tables = dir_models + "/eynollah-tables_20210319"
+            self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
+        if self.tables:
+            if self.light_version:
+                self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
+            else:
+                self.model_table_dir = dir_models + "/eynollah-tables_20210319"
 
         self.models = {}
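For quick reference while reviewing: the commit replaces the single `self.model_tables` path with a per-mode `self.model_table_dir`. A minimal sketch of the resulting selection logic, with `dir_models`, `tables`, and `light_version` mirroring the constructor flags in the diff (the surrounding Eynollah class is elided):

    # Sketch only: the real code sets self.model_table_dir in __init__.
    def select_table_model_dir(dir_models, tables, light_version):
        if not tables:
            return None  # table detection disabled, nothing to load
        if light_version:
            # new ensemble table model introduced by this commit
            return dir_models + "/modelens_table_0t4_201124"
        # previous model, still used by the non-light pipeline
        return dir_models + "/eynollah-tables_20210319"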
@@ -290,6 +294,9 @@ class Eynollah:
             self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
+        if self.tables:
+            self.model_table = self.our_load_model(self.model_table_dir)
 
             self.ls_imgs = os.listdir(self.dir_in)
@@ -325,10 +332,14 @@ class Eynollah:
             self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
             self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
             self.model_reading_order_machine = self.our_load_model(self.model_reading_order_machine_dir)
+            if self.tables:
+                self.model_table = self.our_load_model(self.model_table_dir)
 
             self.ls_imgs = os.listdir(self.dir_in)
 
     def _cache_images(self, image_filename=None, image_pil=None):
         ret = {}
         t_c0 = time.time()
@@ -2326,7 +2337,22 @@ class Eynollah:
             ###img_bin = np.copy(prediction_bin)
             ###else:
             ###img_bin = np.copy(img_resized)
-        img_bin = np.copy(img_resized)
+        if self.ocr and not self.input_binary:
+            if not self.dir_in:
+                model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
+                prediction_bin = self.do_prediction(True, img_resized, model_bin, n_batch_inference=5)
+            else:
+                prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
+            prediction_bin = prediction_bin[:,:,0]
+            prediction_bin = (prediction_bin[:,:]==0)*1
+            prediction_bin = prediction_bin*255
+            prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
+            prediction_bin = prediction_bin.astype(np.uint16)
+            #img= np.copy(prediction_bin)
+            img_bin = np.copy(prediction_bin)
+        else:
+            img_bin = np.copy(img_resized)
 
         #print("inside 1 ", time.time()-t_in)
@@ -3175,13 +3201,23 @@ class Eynollah:
         img_height_h = img_org.shape[0]
         img_width_h = img_org.shape[1]
 
-        model_region, session_region = self.start_new_session_and_model(self.model_tables)
+        if self.dir_in:
+            pass
+        else:
+            self.model_table, _ = self.start_new_session_and_model(self.model_table_dir)
+
         patches = False
-        if num_col_classifier < 4 and num_col_classifier > 2:
-            prediction_table = self.do_prediction(patches, img, model_region)
-            pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), model_region)
-            pre_updown = cv2.flip(pre_updown, -1)
-            prediction_table[:,:,0][pre_updown[:,:,0]==1] = 1
+        if self.light_version:
+            prediction_table = self.do_prediction_new_concept(patches, img, self.model_table)
+            prediction_table = prediction_table.astype(np.int16)
+            return prediction_table[:,:,0]
+        else:
+            if num_col_classifier < 4 and num_col_classifier > 2:
+                prediction_table = self.do_prediction(patches, img, self.model_table)
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
+                pre_updown = cv2.flip(pre_updown, -1)
+                prediction_table[:,:,0][pre_updown[:,:,0]==1] = 1
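The non-light branch kept here uses a flip ensemble: the table model runs on the page and on its 180-degree flip, the flipped result is flipped back, and the two masks are OR-ed. A minimal sketch, where `predict` stands in for a bound `self.do_prediction` call:

    import cv2

    def flip_ensemble(img, predict):
        pred = predict(img)
        # cv2.flip(..., -1) flips both axes, i.e. a 180-degree rotation
        pred_updown = cv2.flip(predict(cv2.flip(img, -1)), -1)
        # union of the two table masks
        pred[:, :, 0][pred_updown[:, :, 0] == 1] = 1
        return pred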
@@ -3199,8 +3235,8 @@ class Eynollah:
                 img_new = np.ones((height_new, width_new, img.shape[2])).astype(float)*0
                 img_new[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1], :] = img[:,:,:]
 
-                prediction_ext = self.do_prediction(patches, img_new, model_region)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), model_region)
+                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table = prediction_ext[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1], :]
@@ -3221,8 +3257,8 @@ class Eynollah:
                 img_new = np.ones((height_new, width_new, img.shape[2])).astype(float)*0
                 img_new[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1], :] = img[:,:,:]
 
-                prediction_ext = self.do_prediction(patches, img_new, model_region)
-                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), model_region)
+                prediction_ext = self.do_prediction(patches, img_new, self.model_table)
+                pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), self.model_table)
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table = prediction_ext[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1], :]
@@ -3235,10 +3271,10 @@ class Eynollah:
                 prediction_table = np.zeros(img.shape)
                 img_w_half = int(img.shape[1]/2.)
 
-                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], model_region)
-                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], model_region)
-                pre_full = self.do_prediction(patches, img[:,:,:], model_region)
-                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), model_region)
+                pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], self.model_table)
+                pre2 = self.do_prediction(patches, img[:,img_w_half:,:], self.model_table)
+                pre_full = self.do_prediction(patches, img[:,:,:], self.model_table)
+                pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), self.model_table)
                 pre_updown = cv2.flip(pre_updown, -1)
 
                 prediction_table_full_erode = cv2.erode(pre_full[:,:,0], KERNEL, iterations=4)
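A sketch of the multi-view setup in this hunk (the merge logic sits outside it): for this column count the table model runs on each page half, on the full page, and on the 180-degree flip, and the full-page mask is eroded to trim its boundary. `predict` stands in for a bound `self.do_prediction` call, and the KERNEL shape is an assumption, not the value defined in eynollah:

    import cv2
    import numpy as np

    KERNEL = np.ones((5, 5), np.uint8)  # assumed kernel, see note above

    def table_views(img, predict):
        half = int(img.shape[1] / 2.)
        pre1 = predict(img[:, :half, :])      # left half
        pre2 = predict(img[:, half:, :])      # right half
        pre_full = predict(img)               # whole page
        pre_updown = cv2.flip(predict(cv2.flip(img, -1)), -1)
        # shrink the full-page mask before it is merged with the others
        full_erode = cv2.erode(pre_full[:, :, 0], KERNEL, iterations=4)
        return pre1, pre2, full_erode, pre_updown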
@@ -3500,6 +3536,9 @@ class Eynollah:
             #print(time.time()-t_0_box,'time box in 3.1')
 
             if self.tables:
+                if self.light_version:
+                    pass
+                else:
                     text_regions_p_tables = np.copy(text_regions_p)
                     text_regions_p_tables[:,:][(table_prediction[:,:] == 1)] = 10
                     pixel_line = 3
@@ -3513,6 +3552,9 @@ class Eynollah:
                 self.logger.debug("len(boxes): %s", len(boxes_d))
 
             if self.tables:
+                if self.light_version:
+                    pass
+                else:
                     text_regions_p_tables = np.copy(text_regions_p_1_n)
                     text_regions_p_tables = np.round(text_regions_p_tables)
                     text_regions_p_tables[:,:][(text_regions_p_tables[:,:] != 3) & (table_prediction_n[:,:] == 1)] = 10
@@ -3529,6 +3571,10 @@ class Eynollah:
         self.logger.info("detecting boxes took %.1fs", time.time() - t1)
 
         if self.tables:
+            if self.light_version:
+                text_regions_p[:,:][table_prediction[:,:] == 1] = 10
+                img_revised_tab = text_regions_p[:,:]
+            else:
                 if np.abs(slope_deskew) < SLOPE_THRESHOLD:
                     img_revised_tab = np.copy(img_revised_tab2[:,:,0])
                     img_revised_tab[:,:][(text_regions_p[:,:] == 1) & (img_revised_tab[:,:] != 10)] = 1
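The light-version shortcut added above writes table pixels from the new model straight into the region map as class 10 (the table label used throughout this diff), skipping the heavy pipeline's deskew-dependent refinement. As a standalone sketch:

    import numpy as np

    def apply_table_label(text_regions_p, table_prediction):
        out = np.copy(text_regions_p)
        out[table_prediction == 1] = 10  # 10 = table region
        return out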
@@ -3542,6 +3588,9 @@ class Eynollah:
             else:
                 img_revised_tab = text_regions_p[:,:]
             #img_revised_tab = text_regions_p[:, :]
+        if self.light_version:
+            polygons_of_images = return_contours_of_interested_region(text_regions_p, 2)
+        else:
             polygons_of_images = return_contours_of_interested_region(img_revised_tab, 2)
 
         pixel_img = 4
@@ -3565,6 +3614,26 @@ class Eynollah:
         self.logger.debug('enter run_boxes_full_layout')
         t_full0 = time.time()
 
         if self.tables:
+            if self.light_version:
+                text_regions_p[:,:][table_prediction[:,:] == 1] = 10
+                img_revised_tab = text_regions_p[:,:]
+                if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
+                    image_page_rotated_n, textline_mask_tot_d, text_regions_p_1_n, table_prediction_n = rotation_not_90_func(image_page, textline_mask_tot, text_regions_p, table_prediction, slope_deskew)
+
+                    text_regions_p_1_n = resize_image(text_regions_p_1_n, text_regions_p.shape[0], text_regions_p.shape[1])
+                    textline_mask_tot_d = resize_image(textline_mask_tot_d, text_regions_p.shape[0], text_regions_p.shape[1])
+                    table_prediction_n = resize_image(table_prediction_n, text_regions_p.shape[0], text_regions_p.shape[1])
+
+                    regions_without_separators_d = (text_regions_p_1_n[:,:] == 1)*1
+                    regions_without_separators_d[table_prediction_n[:,:] == 1] = 1
+                else:
+                    text_regions_p_1_n = None
+                    textline_mask_tot_d = None
+                    regions_without_separators_d = None
+
+                regions_without_separators = (text_regions_p[:,:] == 1)*1#( (text_regions_p[:,:]==1) | (text_regions_p[:,:]==2) )*1 #self.return_regions_without_seperators_new(text_regions_p[:,:,0],img_only_regions)
+                regions_without_separators[table_prediction == 1] = 1
+            else:
                 if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
                     image_page_rotated_n, textline_mask_tot_d, text_regions_p_1_n, table_prediction_n = rotation_not_90_func(image_page, textline_mask_tot, text_regions_p, table_prediction, slope_deskew)
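A sketch of the deskew handling in the light branch above: when the absolute deskew slope exceeds the threshold, the masks are rotated together with the page and resized back to the original raster so the table mask stays aligned. `rotate_func` and `resize_func` stand in for the `rotation_not_90_func` and `resize_image` helpers named in the diff; their exact signatures here are assumptions:

    def deskewed_masks(image_page, textline_mask, regions, table_pred,
                       slope_deskew, threshold, rotate_func, resize_func):
        if abs(slope_deskew) < threshold:
            # below the threshold the unrotated masks are used as-is
            return None, None, None
        _, textline_d, regions_d, table_d = rotate_func(
            image_page, textline_mask, regions, table_pred, slope_deskew)
        h, w = regions.shape[0], regions.shape[1]
        # bring every rotated mask back to the original page geometry
        return (resize_func(textline_d, h, w),
                resize_func(regions_d, h, w),
                resize_func(table_d, h, w))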
