run_single: simplify; allow running TrOCR in non-fl mode, too

- refactor final `self.full_layout` conditional, removing copied code
- allow running `self.ocr` and `self.tr` branch in both cases (non/fl)
- when running TrOCR, use model / processor / device initialised during init
  (instead of ad-hoc loading)
This commit is contained in:
Robert Sachunsky 2025-10-06 17:24:50 +02:00
parent 6e57ab3741
commit 595ed02743

View file

@ -379,9 +379,14 @@ class Eynollah:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr and self.tr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
self.device = torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
self.device = torch.device("cpu")
#self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
elif self.ocr and not self.tr:
model_ocr = load_model(self.model_ocr_dir , compile=False)
@ -4805,12 +4810,13 @@ class Eynollah:
slopes_marginals, mid_point_of_page_width)
#print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred')
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order(
contours_only_text_parent_d_ordered, index_by_text_par_con)
else:
contours_only_text_parent_d_ordered = None
if self.full_layout:
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order(
contours_only_text_parent_d_ordered, index_by_text_par_con)
else:
contours_only_text_parent_d_ordered = None
if self.light_version:
fun = check_any_text_region_in_model_one_is_main_or_header_light
else:
@ -4869,44 +4875,43 @@ class Eynollah:
splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d,
num_col_classifier, erosion_hurts, self.tables, self.right2left,
logger=self.logger)
else:
contours_only_text_parent_h = []
contours_only_text_parent_h_d_ordered = []
if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page)
t_order = time.time()
if self.full_layout:
self.logger.info("Step 4/5: Reading Order Detection")
if self.reading_order_machine_based:
self.logger.info("Using machine-based detection")
if self.right2left:
self.logger.info("Right-to-left mode enabled")
if self.headers_off:
self.logger.info("Headers ignored in reading order")
#if self.full_layout:
self.logger.info("Step 4/5: Reading Order Detection")
if self.reading_order_machine_based:
tror = time.time()
order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
contours_only_text_parent, contours_only_text_parent_h, text_regions_p)
if self.reading_order_machine_based:
self.logger.info("Using machine-based detection")
if self.right2left:
self.logger.info("Right-to-left mode enabled")
if self.headers_off:
self.logger.info("Headers ignored in reading order")
if self.reading_order_machine_based:
order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
contours_only_text_parent, contours_only_text_parent_h, text_regions_p)
else:
if np.abs(slope_deskew) < SLOPE_THRESHOLD:
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot)
else:
if np.abs(slope_deskew) < SLOPE_THRESHOLD:
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot)
else:
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered,
boxes_d, textline_mask_tot_d)
self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s")
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered,
boxes_d, textline_mask_tot_d)
self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s")
if self.ocr and not self.tr:
self.logger.info("Step 4.5/5: OCR Processing")
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
else:
self.logger.info("Using CPU processing")
if self.ocr:
self.logger.info("Step 4.5/5: OCR Processing")
if not self.tr:
gc.collect()
if len(all_found_textline_polygons)>0:
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, self.prediction_model,
@ -4941,15 +4946,68 @@ class Eynollah:
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines_drop = None
else:
ocr_all_textlines = None
ocr_all_textlines_marginals_left = None
ocr_all_textlines_marginals_right = None
ocr_all_textlines_h = None
ocr_all_textlines_drop = None
if self.light_version:
self.logger.info("Using light version OCR")
if self.textline_light:
self.logger.info("Using light text line detection for OCR")
self.logger.info("Processing text lines...")
self.device.reset()
gc.collect()
torch.cuda.empty_cache()
self.model_ocr.to(self.device)
ind_tot = 0
#cv2.imwrite('./img_out.png', image_page)
ocr_all_textlines = []
for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
ocr_textline_in_textregion = []
for indexing2, ind_poly in enumerate(ind_poly_first):
if not (self.textline_light or self.curved_line):
ind_poly = copy.deepcopy(ind_poly)
box_ind = all_box_coord[indexing]
#print(ind_poly,np.shape(ind_poly), 'ind_poly')
#print(box_ind)
ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
#print(ind_poly_copy)
ind_poly[ind_poly<0] = 0
x, y, w, h = cv2.boundingRect(ind_poly)
#print(ind_poly_copy, np.shape(ind_poly_copy))
#print(x, y, w, h, h/float(w),'ratio')
h2w_ratio = h/float(w)
mask_poly = np.zeros(image_page.shape)
if not self.light_version:
img_poly_on_img = np.copy(image_page)
else:
img_poly_on_img = np.copy(img_bin_light)
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
if self.textline_light:
mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255
img_croped = img_poly_on_img[y:y+h, x:x+w, :]
#cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
text_ocr = self.return_ocr_of_textline_without_common_section(
img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr)
ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion)
else:
ocr_all_textlines = None
ocr_all_textlines_marginals_left = None
ocr_all_textlines_marginals_right = None
ocr_all_textlines_h = None
ocr_all_textlines_drop = None
self.logger.info("Step 5/5: Output Generation")
self.logger.info("Step 5/5: Output Generation")
if self.full_layout:
pcgts = self.writer.build_pagexml_full_layout(
contours_only_text_parent, contours_only_text_parent_h, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h,
@ -4962,129 +5020,18 @@ class Eynollah:
ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right,
ocr_all_textlines_drop,
conf_contours_textregions, conf_contours_textregions_h)
return pcgts
contours_only_text_parent_h = []
self.logger.info("Step 4/5: Reading Order Detection")
if self.reading_order_machine_based:
self.logger.info("Using machine-based detection")
if self.right2left:
self.logger.info("Right-to-left mode enabled")
if self.headers_off:
self.logger.info("Headers ignored in reading order")
if self.reading_order_machine_based:
order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
contours_only_text_parent, contours_only_text_parent_h, text_regions_p)
else:
if np.abs(slope_deskew) < SLOPE_THRESHOLD:
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot)
else:
contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order(
contours_only_text_parent_d_ordered, index_by_text_par_con)
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)
if self.ocr and self.tr:
self.logger.info("Step 4.5/5: OCR Processing")
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
else:
self.logger.info("Using CPU processing")
if self.light_version:
self.logger.info("Using light version OCR")
if self.textline_light:
self.logger.info("Using light text line detection for OCR")
self.logger.info("Processing text lines...")
pcgts = self.writer.build_pagexml_no_full_layout(
txt_con_org, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_box_coord, polygons_of_images,
polygons_of_marginals_left, polygons_of_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines, contours_tables, ocr_all_textlines,
ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right,
conf_contours_textregions)
device = cuda.get_current_device()
device.reset()
gc.collect()
model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
torch.cuda.empty_cache()
model_ocr.to(device)
ind_tot = 0
#cv2.imwrite('./img_out.png', image_page)
ocr_all_textlines = []
for indexing, ind_poly_first in enumerate(all_found_textline_polygons):
ocr_textline_in_textregion = []
for indexing2, ind_poly in enumerate(ind_poly_first):
if not (self.textline_light or self.curved_line):
ind_poly = copy.deepcopy(ind_poly)
box_ind = all_box_coord[indexing]
#print(ind_poly,np.shape(ind_poly), 'ind_poly')
#print(box_ind)
ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind)
#print(ind_poly_copy)
ind_poly[ind_poly<0] = 0
x, y, w, h = cv2.boundingRect(ind_poly)
#print(ind_poly_copy, np.shape(ind_poly_copy))
#print(x, y, w, h, h/float(w),'ratio')
h2w_ratio = h/float(w)
mask_poly = np.zeros(image_page.shape)
if not self.light_version:
img_poly_on_img = np.copy(image_page)
else:
img_poly_on_img = np.copy(img_bin_light)
mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1))
if self.textline_light:
mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1)
img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255
img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255
img_croped = img_poly_on_img[y:y+h, x:x+w, :]
#cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
text_ocr = self.return_ocr_of_textline_without_common_section(
img_croped, model_ocr, processor, device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr)
ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion)
elif self.ocr and not self.tr:
gc.collect()
if len(all_found_textline_polygons)>0:
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if all_found_textline_polygons_marginals_left and len(all_found_textline_polygons_marginals_left)>0:
ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_left, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if all_found_textline_polygons_marginals_right and len(all_found_textline_polygons_marginals_right)>0:
ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_right, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines = None
ocr_all_textlines_marginals_left = None
ocr_all_textlines_marginals_right = None
self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s")
self.logger.info("Step 5/5: Output Generation")
self.logger.info("Generating PAGE-XML output")
pcgts = self.writer.build_pagexml_no_full_layout(
txt_con_org, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_box_coord, polygons_of_images,
polygons_of_marginals_left, polygons_of_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines, contours_tables, ocr_all_textlines,
ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right,
conf_contours_textregions)
return pcgts