From 6e008345a05a4ae47f7a36b82a755219a3b2d868 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Mon, 15 Sep 2025 13:36:58 +0200 Subject: [PATCH] new page extraction model integration --- src/eynollah/eynollah.py | 200 +++++++++++++++++++++++++++++++-------- 1 file changed, 160 insertions(+), 40 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index ec2900f..3288b75 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -285,7 +285,7 @@ class Eynollah: #"/eynollah-full-regions-1column_20210425" self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124" #self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425" - self.model_page_dir = dir_models + "/eynollah-page-extraction_20210425" + self.model_page_dir = dir_models + "/model_ens_page" self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425" self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314" self.model_region_dir_p_ens_light_only_images_extraction = dir_models + "/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18" @@ -1591,11 +1591,11 @@ class Eynollah: self.logger.debug("enter extract_page") cont_page = [] if not self.ignore_page_extraction: - img = cv2.GaussianBlur(self.image, (5, 5), 0) + img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0) img_page_prediction = self.do_prediction(False, img, self.model_page) imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(imgray, 0, 255, 0) - thresh = cv2.dilate(thresh, KERNEL, iterations=3) + ##thresh = cv2.dilate(thresh, KERNEL, iterations=3) contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) if len(contours)>0: @@ -1603,24 +1603,25 @@ class Eynollah: for j in range(len(contours))]) cnt = contours[np.argmax(cnt_size)] x, y, w, h = cv2.boundingRect(cnt) - if x <= 30: - w += x - x = 0 - if (self.image.shape[1] - (x + w)) <= 30: - w = w + (self.image.shape[1] - (x + w)) - if y <= 30: - h = h + y - y = 0 - if (self.image.shape[0] - (y + h)) <= 30: - h = h + (self.image.shape[0] - (y + h)) + #if x <= 30: + #w += x + #x = 0 + #if (self.image.shape[1] - (x + w)) <= 30: + #w = w + (self.image.shape[1] - (x + w)) + #if y <= 30: + #h = h + y + #y = 0 + #if (self.image.shape[0] - (y + h)) <= 30: + #h = h + (self.image.shape[0] - (y + h)) box = [x, y, w, h] else: box = [0, 0, img.shape[1], img.shape[0]] cropped_page, page_coord = crop_image_inside_box(box, self.image) - cont_page.append(np.array([[page_coord[2], page_coord[0]], - [page_coord[3], page_coord[0]], - [page_coord[3], page_coord[1]], - [page_coord[2], page_coord[1]]])) + cont_page = cnt + #cont_page.append(np.array([[page_coord[2], page_coord[0]], + #[page_coord[3], page_coord[0]], + #[page_coord[3], page_coord[1]], + #[page_coord[2], page_coord[1]]])) self.logger.debug("exit extract_page") else: box = [0, 0, self.image.shape[1], self.image.shape[0]] @@ -3063,10 +3064,20 @@ class Eynollah: if self.plotter: self.plotter.save_page_image(image_page) - + + mask_page = np.zeros((text_regions_p_1.shape[0], text_regions_p_1.shape[1])).astype(np.int8) + mask_page = cv2.fillPoly(mask_page, pts=[cont_page], color=(1,)) + + text_regions_p_1[mask_page==0] = 0 + textline_mask_tot_ea[mask_page==0] = 0 + text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + + ###text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + ###textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] + ###img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]] mask_images = (text_regions_p_1[:, :] == 2) * 1 mask_images = mask_images.astype(np.uint8) @@ -5299,8 +5310,12 @@ class Eynollah_ocr: cropped_lines = [] cropped_lines_region_indexer = [] cropped_lines_meging_indexing = [] + + extracted_texts = [] indexer_text_region = 0 + indexer_b_s = 0 + for nn in root1.iter(region_tags): for child_textregion in nn: if child_textregion.tag.endswith("TextLine"): @@ -5325,40 +5340,105 @@ class Eynollah_ocr: img_crop = img_poly_on_img[y:y+h, x:x+w, :] img_crop[mask_poly==0] = 255 + if h2w_ratio > 0.1: cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width, tr_ocr_input_height_and_width) ) cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + else: splited_images, _ = return_textlines_split_if_needed(img_crop, None) #print(splited_images) if splited_images: cropped_lines.append(resize_image(splited_images[0], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width)) cropped_lines_meging_indexing.append(1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width)) cropped_lines_meging_indexing.append(-1) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + else: cropped_lines.append(img_crop) cropped_lines_meging_indexing.append(0) + indexer_b_s+=1 + + if indexer_b_s==self.b_s: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + extracted_texts = extracted_texts + generated_text_merged + + + indexer_text_region = indexer_text_region +1 - - extracted_texts = [] - n_iterations = math.ceil(len(cropped_lines) / self.b_s) - - for i in range(n_iterations): - if i==(n_iterations-1): - n_start = i*self.b_s - imgs = cropped_lines[n_start:] - else: - n_start = i*self.b_s - n_end = (i+1)*self.b_s - imgs = cropped_lines[n_start:n_end] + if indexer_b_s!=0: + imgs = cropped_lines[:] + cropped_lines = [] + indexer_b_s = 0 + pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) extracted_texts = extracted_texts + generated_text_merged + ####extracted_texts = [] + ####n_iterations = math.ceil(len(cropped_lines) / self.b_s) + + ####for i in range(n_iterations): + ####if i==(n_iterations-1): + ####n_start = i*self.b_s + ####imgs = cropped_lines[n_start:] + ####else: + ####n_start = i*self.b_s + ####n_end = (i+1)*self.b_s + ####imgs = cropped_lines[n_start:n_end] + ####pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values + ####generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) + ####generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) + + ####extracted_texts = extracted_texts + generated_text_merged + del cropped_lines gc.collect() @@ -5409,31 +5489,71 @@ class Eynollah_ocr: #print(time.time() - t0 ,'elapsed time') - indexer = 0 indexer_textregion = 0 for nn in root1.iter(region_tags): - text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') - unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') + #id_textregion = nn.attrib['id'] + #id_textregions.append(id_textregion) + #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) + + is_textregion_text = False + for childtest in nn: + if childtest.tag.endswith("TextEquiv"): + is_textregion_text = True + + if not is_textregion_text: + text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') + unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') has_textline = False for child_textregion in nn: if child_textregion.tag.endswith("TextLine"): - text_subelement = ET.SubElement(child_textregion, 'TextEquiv') - unicode_textline = ET.SubElement(text_subelement, 'Unicode') - unicode_textline.text = extracted_texts_merged[indexer] + + is_textline_text = False + for childtest2 in child_textregion: + if childtest2.tag.endswith("TextEquiv"): + is_textline_text = True + + + if not is_textline_text: + text_subelement = ET.SubElement(child_textregion, 'TextEquiv') + ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + unicode_textline = ET.SubElement(text_subelement, 'Unicode') + unicode_textline.text = extracted_texts_merged[indexer] + else: + for childtest3 in child_textregion: + if childtest3.tag.endswith("TextEquiv"): + for child_uc in childtest3: + if child_uc.tag.endswith("Unicode"): + ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}") + child_uc.text = extracted_texts_merged[indexer] + indexer = indexer + 1 has_textline = True if has_textline: - unicode_textregion.text = text_by_textregion[indexer_textregion] + if is_textregion_text: + for child4 in nn: + if child4.tag.endswith("TextEquiv"): + for childtr_uc in child4: + if childtr_uc.tag.endswith("Unicode"): + childtr_uc.text = text_by_textregion[indexer_textregion] + else: + unicode_textregion.text = text_by_textregion[indexer_textregion] indexer_textregion = indexer_textregion + 1 - - + ###sample_order = [(id_to_order[tid], text) for tid, text in zip(id_textregions, textregions_by_existing_ids) if tid in id_to_order] + + ##ordered_texts_sample = [text for _, text in sorted(sample_order)] + ##tot_page_text = ' '.join(ordered_texts_sample) + + ##for page_element in root1.iter(link+'Page'): + ##text_page = ET.SubElement(page_element, 'TextEquiv') + ##unicode_textpage = ET.SubElement(text_page, 'Unicode') + ##unicode_textpage.text = tot_page_text + ET.register_namespace("",name_space) tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) - #print("Job done in %.1fs", time.time() - t0) else: ###max_len = 280#512#280#512 ###padding_token = 1500#299#1500#299