new page extraction model integration

vahidrezanezhad 2025-09-15 13:36:58 +02:00
parent fdcae8dd6e
commit 6e008345a0


@@ -285,7 +285,7 @@ class Eynollah:
             #"/eynollah-full-regions-1column_20210425"
         self.model_region_dir_fully_np = dir_models + "/modelens_full_lay_1__4_3_091124"
         #self.model_region_dir_fully = dir_models + "/eynollah-full-regions-3+column_20210425"
-        self.model_page_dir = dir_models + "/eynollah-page-extraction_20210425"
+        self.model_page_dir = dir_models + "/model_ens_page"
         self.model_region_dir_p_ens = dir_models + "/eynollah-main-regions-ensembled_20210425"
         self.model_region_dir_p_ens_light = dir_models + "/eynollah-main-regions_20220314"
         self.model_region_dir_p_ens_light_only_images_extraction = dir_models + "/eynollah-main-regions_20231127_672_org_ens_11_13_16_17_18"
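
Note: the hunk above swaps the 2021 page-extraction model for the new ensemble directory model_ens_page. As a minimal sketch of how such a versioned model directory can be resolved and checked before loading — resolve_model_dir is a hypothetical helper, not part of this commit:

    import os

    def resolve_model_dir(dir_models: str, name: str) -> str:
        """Join the models root with a model directory name and fail fast
        if it is missing (hypothetical helper, not in eynollah)."""
        path = os.path.join(dir_models, name)
        if not os.path.isdir(path):
            raise FileNotFoundError("model directory not found: " + path)
        return path

    # Mirrors the change above:
    # old: resolve_model_dir(dir_models, "eynollah-page-extraction_20210425")
    # new: resolve_model_dir(dir_models, "model_ens_page")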
@@ -1591,11 +1591,11 @@ class Eynollah:
         self.logger.debug("enter extract_page")
         cont_page = []
         if not self.ignore_page_extraction:
-            img = cv2.GaussianBlur(self.image, (5, 5), 0)
+            img = np.copy(self.image)#cv2.GaussianBlur(self.image, (5, 5), 0)
             img_page_prediction = self.do_prediction(False, img, self.model_page)
             imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
             _, thresh = cv2.threshold(imgray, 0, 255, 0)
-            thresh = cv2.dilate(thresh, KERNEL, iterations=3)
+            ##thresh = cv2.dilate(thresh, KERNEL, iterations=3)
             contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
             if len(contours)>0:
@@ -1603,24 +1603,25 @@ class Eynollah:
                                      for j in range(len(contours))])
                 cnt = contours[np.argmax(cnt_size)]
                 x, y, w, h = cv2.boundingRect(cnt)
-                if x <= 30:
-                    w += x
-                    x = 0
-                if (self.image.shape[1] - (x + w)) <= 30:
-                    w = w + (self.image.shape[1] - (x + w))
-                if y <= 30:
-                    h = h + y
-                    y = 0
-                if (self.image.shape[0] - (y + h)) <= 30:
-                    h = h + (self.image.shape[0] - (y + h))
+                #if x <= 30:
+                    #w += x
+                    #x = 0
+                #if (self.image.shape[1] - (x + w)) <= 30:
+                    #w = w + (self.image.shape[1] - (x + w))
+                #if y <= 30:
+                    #h = h + y
+                    #y = 0
+                #if (self.image.shape[0] - (y + h)) <= 30:
+                    #h = h + (self.image.shape[0] - (y + h))
                 box = [x, y, w, h]
             else:
                 box = [0, 0, img.shape[1], img.shape[0]]
             cropped_page, page_coord = crop_image_inside_box(box, self.image)
-            cont_page.append(np.array([[page_coord[2], page_coord[0]],
-                                       [page_coord[3], page_coord[0]],
-                                       [page_coord[3], page_coord[1]],
-                                       [page_coord[2], page_coord[1]]]))
+            cont_page = cnt
+            #cont_page.append(np.array([[page_coord[2], page_coord[0]],
+                                       #[page_coord[3], page_coord[0]],
+                                       #[page_coord[3], page_coord[1]],
+                                       #[page_coord[2], page_coord[1]]]))
             self.logger.debug("exit extract_page")
         else:
             box = [0, 0, self.image.shape[1], self.image.shape[0]]
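
Note: with the new page model, extract_page no longer blurs the input, no longer dilates the thresholded prediction, and no longer snaps the bounding box to the image border; the largest contour itself is now kept as cont_page instead of a rectangle rebuilt from page_coord. A minimal standalone sketch of that revised flow, where predict_page stands in for the do_prediction model call:

    import cv2
    import numpy as np

    def extract_page_sketch(image, predict_page):
        """Predict a page mask on an unblurred copy, threshold it, and keep
        the largest contour as the page border (sketch of the new flow)."""
        img = np.copy(image)                         # blur step disabled
        pred = predict_page(img)                     # assumed HxWx3 label image
        gray = cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY)
        _, thresh = cv2.threshold(gray, 0, 255, 0)   # dilation step disabled
        contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return [0, 0, image.shape[1], image.shape[0]], None
        cnt = max(contours, key=cv2.contourArea)     # largest region = page
        x, y, w, h = cv2.boundingRect(cnt)           # border snapping disabled
        return [x, y, w, h], cnt                     # contour kept as cont_page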
@@ -3064,10 +3065,20 @@ class Eynollah:
             if self.plotter:
                 self.plotter.save_page_image(image_page)
+            mask_page = np.zeros((text_regions_p_1.shape[0], text_regions_p_1.shape[1])).astype(np.int8)
+            mask_page = cv2.fillPoly(mask_page, pts=[cont_page], color=(1,))
+            text_regions_p_1[mask_page==0] = 0
+            textline_mask_tot_ea[mask_page==0] = 0
-            text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
-            textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
-            img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
+            ###text_regions_p_1 = text_regions_p_1[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
+            ###textline_mask_tot_ea = textline_mask_tot_ea[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
+            ###img_bin_light = img_bin_light[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3]]
             mask_images = (text_regions_p_1[:, :] == 2) * 1
             mask_images = mask_images.astype(np.uint8)
             mask_images = cv2.erode(mask_images[:, :], KERNEL, iterations=10)
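
Note: the hunk above stops cropping the layout predictions to the page box and instead zeroes everything outside the page contour, so downstream coordinates stay in the full image frame. A minimal sketch of that masking step, assuming cont_page is an OpenCV contour as returned by extract_page (uint8 is used for the mask here; the commit uses int8):

    import cv2
    import numpy as np

    def mask_outside_page(label_map, cont_page):
        """Zero label pixels outside the page contour instead of cropping
        to its bounding box, keeping full-frame coordinates."""
        mask_page = np.zeros(label_map.shape[:2], dtype=np.uint8)
        mask_page = cv2.fillPoly(mask_page, pts=[cont_page], color=(1,))
        out = label_map.copy()
        out[mask_page == 0] = 0
        return out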
@@ -5300,7 +5311,11 @@ class Eynollah_ocr:
                 cropped_lines_region_indexer = []
                 cropped_lines_meging_indexing = []
+                extracted_texts = []
+                indexer_text_region = 0
+                indexer_b_s = 0
                 for nn in root1.iter(region_tags):
                     for child_textregion in nn:
                         if child_textregion.tag.endswith("TextLine"):
@@ -5325,40 +5340,105 @@ class Eynollah_ocr:
                             img_crop = img_poly_on_img[y:y+h, x:x+w, :]
                             img_crop[mask_poly==0] = 255
                             if h2w_ratio > 0.1:
                                 cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width, tr_ocr_input_height_and_width) )
                                 cropped_lines_meging_indexing.append(0)
+                                indexer_b_s+=1
+                                if indexer_b_s==self.b_s:
+                                    imgs = cropped_lines[:]
+                                    cropped_lines = []
+                                    indexer_b_s = 0
+                                    pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
+                                    generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
+                                    generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                                    extracted_texts = extracted_texts + generated_text_merged
                             else:
                                 splited_images, _ = return_textlines_split_if_needed(img_crop, None)
                                 #print(splited_images)
                                 if splited_images:
                                     cropped_lines.append(resize_image(splited_images[0], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
                                     cropped_lines_meging_indexing.append(1)
-                                    cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
-                                    cropped_lines_meging_indexing.append(-1)
-                                else:
-                                    cropped_lines.append(img_crop)
-                                    cropped_lines_meging_indexing.append(0)
-                            indexer_text_region = indexer_text_region +1
+                                    indexer_b_s+=1
+                                    if indexer_b_s==self.b_s:
+                                        imgs = cropped_lines[:]
+                                        cropped_lines = []
+                                        indexer_b_s = 0
-                    extracted_texts = []
-                    n_iterations = math.ceil(len(cropped_lines) / self.b_s)
-                    for i in range(n_iterations):
-                        if i==(n_iterations-1):
-                            n_start = i*self.b_s
-                            imgs = cropped_lines[n_start:]
-                        else:
-                            n_start = i*self.b_s
-                            n_end = (i+1)*self.b_s
-                            imgs = cropped_lines[n_start:n_end]
-                        pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
-                        generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
-                        generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
-                        extracted_texts = extracted_texts + generated_text_merged
+                                    cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
+                                    cropped_lines_meging_indexing.append(-1)
+                                    indexer_b_s+=1
+                                    if indexer_b_s==self.b_s:
+                                        imgs = cropped_lines[:]
+                                        cropped_lines = []
+                                        indexer_b_s = 0
+                                        pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
+                                        generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
+                                        generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                                        extracted_texts = extracted_texts + generated_text_merged
+                                else:
+                                    cropped_lines.append(img_crop)
+                                    cropped_lines_meging_indexing.append(0)
+                                    indexer_b_s+=1
+                                    if indexer_b_s==self.b_s:
+                                        imgs = cropped_lines[:]
+                                        cropped_lines = []
+                                        indexer_b_s = 0
+                                        pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
+                                        generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
+                                        generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                                        extracted_texts = extracted_texts + generated_text_merged
+                            indexer_text_region = indexer_text_region +1
+                    if indexer_b_s!=0:
+                        imgs = cropped_lines[:]
+                        cropped_lines = []
+                        indexer_b_s = 0
+                        pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
+                        generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
+                        generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                        extracted_texts = extracted_texts + generated_text_merged
+                    ####extracted_texts = []
+                    ####n_iterations = math.ceil(len(cropped_lines) / self.b_s)
+                    ####for i in range(n_iterations):
+                        ####if i==(n_iterations-1):
+                            ####n_start = i*self.b_s
+                            ####imgs = cropped_lines[n_start:]
+                        ####else:
+                            ####n_start = i*self.b_s
+                            ####n_end = (i+1)*self.b_s
+                            ####imgs = cropped_lines[n_start:n_end]
+                        ####pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
+                        ####generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
+                        ####generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
+                        ####extracted_texts = extracted_texts + generated_text_merged
                     del cropped_lines
                     gc.collect()
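
Note: the hunk above replaces the single post-hoc batching pass (the commented-out math.ceil loop) with incremental flushing: textline crops accumulate in cropped_lines, a batch is sent through the processor/generate/batch_decode pipeline whenever indexer_b_s reaches self.b_s, and a final partial batch is flushed after the loop (if indexer_b_s!=0). The same pattern, reduced to a generic sketch where run_batch stands in for the model calls:

    from typing import Callable, List

    def ocr_in_batches(lines: List, b_s: int,
                       run_batch: Callable[[List], List[str]]) -> List[str]:
        """Flush-on-batch-size: run the model whenever the buffer reaches
        b_s, then flush whatever remains (the indexer_b_s != 0 case)."""
        extracted_texts: List[str] = []
        buffer: List = []
        for line in lines:
            buffer.append(line)
            if len(buffer) == b_s:              # buffer full: run one batch
                extracted_texts += run_batch(buffer)
                buffer = []
        if buffer:                              # final partial batch
            extracted_texts += run_batch(buffer)
        return extracted_texts

This keeps peak memory bounded by the batch size instead of holding every cropped line in memory until the end of the page.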
@@ -5409,10 +5489,19 @@ class Eynollah_ocr:
                 #print(time.time() - t0 ,'elapsed time')
                 indexer = 0
                 indexer_textregion = 0
                 for nn in root1.iter(region_tags):
+                    #id_textregion = nn.attrib['id']
+                    #id_textregions.append(id_textregion)
+                    #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])
+                    is_textregion_text = False
+                    for childtest in nn:
+                        if childtest.tag.endswith("TextEquiv"):
+                            is_textregion_text = True
+                    if not is_textregion_text:
                         text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
                         unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
@@ -5420,20 +5509,51 @@ class Eynollah_ocr:
+                    has_textline = False
                     for child_textregion in nn:
                         if child_textregion.tag.endswith("TextLine"):
+                            is_textline_text = False
+                            for childtest2 in child_textregion:
+                                if childtest2.tag.endswith("TextEquiv"):
+                                    is_textline_text = True
+                            if not is_textline_text:
                                 text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
+                                ##text_subelement.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
                                 unicode_textline = ET.SubElement(text_subelement, 'Unicode')
                                 unicode_textline.text = extracted_texts_merged[indexer]
+                            else:
+                                for childtest3 in child_textregion:
+                                    if childtest3.tag.endswith("TextEquiv"):
+                                        for child_uc in childtest3:
+                                            if child_uc.tag.endswith("Unicode"):
+                                                ##childtest3.set('conf', f"{extracted_conf_value_merged[indexer]:.2f}")
+                                                child_uc.text = extracted_texts_merged[indexer]
                             indexer = indexer + 1
+                            has_textline = True
+                    if has_textline:
+                        if is_textregion_text:
+                            for child4 in nn:
+                                if child4.tag.endswith("TextEquiv"):
+                                    for childtr_uc in child4:
+                                        if childtr_uc.tag.endswith("Unicode"):
+                                            childtr_uc.text = text_by_textregion[indexer_textregion]
+                        else:
+                            unicode_textregion.text = text_by_textregion[indexer_textregion]
                     indexer_textregion = indexer_textregion + 1
+                ###sample_order = [(id_to_order[tid], text) for tid, text in zip(id_textregions, textregions_by_existing_ids) if tid in id_to_order]
+                ##ordered_texts_sample = [text for _, text in sorted(sample_order)]
+                ##tot_page_text = ' '.join(ordered_texts_sample)
+                ##for page_element in root1.iter(link+'Page'):
+                    ##text_page = ET.SubElement(page_element, 'TextEquiv')
+                    ##unicode_textpage = ET.SubElement(text_page, 'Unicode')
+                    ##unicode_textpage.text = tot_page_text
                 ET.register_namespace("",name_space)
                 tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
                 #print("Job done in %.1fs", time.time() - t0)
         else:
             ###max_len = 280#512#280#512
             ###padding_token = 1500#299#1500#299
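
Note: the last two hunks change how recognized text is written back into the PAGE-XML tree: TextEquiv/Unicode children are created only when a TextLine or TextRegion does not already have them, and existing Unicode nodes are updated in place. A hedged sketch of that upsert pattern with xml.etree.ElementTree (namespace handling omitted for brevity):

    import xml.etree.ElementTree as ET

    def set_text_equiv(element, text):
        """Create TextEquiv/Unicode under element only if missing,
        otherwise overwrite the existing Unicode text."""
        text_equiv = None
        for child in element:
            if child.tag.endswith("TextEquiv"):
                text_equiv = child
        if text_equiv is None:
            text_equiv = ET.SubElement(element, 'TextEquiv')
        unicode_el = None
        for child in text_equiv:
            if child.tag.endswith("Unicode"):
                unicode_el = child
        if unicode_el is None:
            unicode_el = ET.SubElement(text_equiv, 'Unicode')
        unicode_el.text = text

    # Usage, as in the hunk: one call per TextLine with its recognized string,
    # e.g. set_text_equiv(child_textregion, extracted_texts_merged[indexer])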