run_single: simplify; allow running TrOCR in non-fl mode, too

- refactor final `self.full_layout` conditional, removing copied code
- allow running the `self.ocr` and `self.tr` branch in both cases (full-layout and non-full-layout)
- when running TrOCR, use model / processor / device initialised during init
  (instead of ad-hoc loading)
This commit is contained in:
Robert Sachunsky 2025-10-06 17:24:50 +02:00
parent 6e57ab3741
commit 595ed02743

View file

@ -379,9 +379,14 @@ class Eynollah:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir) self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr and self.tr: if self.ocr and self.tr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available():
#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten") self.logger.info("Using GPU acceleration")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") self.device = torch.device("cuda:0")
else:
self.logger.info("Using CPU processing")
self.device = torch.device("cpu")
#self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
elif self.ocr and not self.tr: elif self.ocr and not self.tr:
model_ocr = load_model(self.model_ocr_dir , compile=False) model_ocr = load_model(self.model_ocr_dir , compile=False)
@ -4805,12 +4810,13 @@ class Eynollah:
slopes_marginals, mid_point_of_page_width) slopes_marginals, mid_point_of_page_width)
#print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred') #print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred')
if self.full_layout:
if np.abs(slope_deskew) >= SLOPE_THRESHOLD: if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order( contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order(
contours_only_text_parent_d_ordered, index_by_text_par_con) contours_only_text_parent_d_ordered, index_by_text_par_con)
else: else:
contours_only_text_parent_d_ordered = None contours_only_text_parent_d_ordered = None
if self.full_layout:
if self.light_version: if self.light_version:
fun = check_any_text_region_in_model_one_is_main_or_header_light fun = check_any_text_region_in_model_one_is_main_or_header_light
else: else:
@ -4869,12 +4875,15 @@ class Eynollah:
splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d, splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d,
num_col_classifier, erosion_hurts, self.tables, self.right2left, num_col_classifier, erosion_hurts, self.tables, self.right2left,
logger=self.logger) logger=self.logger)
else:
contours_only_text_parent_h = []
contours_only_text_parent_h_d_ordered = []
if self.plotter: if self.plotter:
self.plotter.write_images_into_directory(polygons_of_images, image_page) self.plotter.write_images_into_directory(polygons_of_images, image_page)
t_order = time.time() t_order = time.time()
if self.full_layout: #if self.full_layout:
self.logger.info("Step 4/5: Reading Order Detection") self.logger.info("Step 4/5: Reading Order Detection")
if self.reading_order_machine_based: if self.reading_order_machine_based:
@ -4885,7 +4894,6 @@ class Eynollah:
self.logger.info("Headers ignored in reading order") self.logger.info("Headers ignored in reading order")
if self.reading_order_machine_based: if self.reading_order_machine_based:
tror = time.time()
order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model( order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
contours_only_text_parent, contours_only_text_parent_h, text_regions_p) contours_only_text_parent, contours_only_text_parent_h, text_regions_p)
else: else:
@ -4898,15 +4906,12 @@ class Eynollah:
boxes_d, textline_mask_tot_d) boxes_d, textline_mask_tot_d)
self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s")
if self.ocr and not self.tr: if self.ocr:
self.logger.info("Step 4.5/5: OCR Processing") self.logger.info("Step 4.5/5: OCR Processing")
if torch.cuda.is_available(): if not self.tr:
self.logger.info("Using GPU acceleration")
else:
self.logger.info("Using CPU processing")
gc.collect() gc.collect()
if len(all_found_textline_polygons)>0: if len(all_found_textline_polygons)>0:
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, self.prediction_model, image_page, all_found_textline_polygons, self.prediction_model,
@ -4941,73 +4946,19 @@ class Eynollah:
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else: else:
ocr_all_textlines_drop = None ocr_all_textlines_drop = None
else: else:
ocr_all_textlines = None
ocr_all_textlines_marginals_left = None
ocr_all_textlines_marginals_right = None
ocr_all_textlines_h = None
ocr_all_textlines_drop = None
self.logger.info("Step 5/5: Output Generation")
pcgts = self.writer.build_pagexml_full_layout(
contours_only_text_parent, contours_only_text_parent_h, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h,
polygons_of_images, contours_tables, polygons_of_drop_capitals,
polygons_of_marginals_left, polygons_of_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_h, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines, ocr_all_textlines, ocr_all_textlines_h,
ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right,
ocr_all_textlines_drop,
conf_contours_textregions, conf_contours_textregions_h)
return pcgts
contours_only_text_parent_h = []
self.logger.info("Step 4/5: Reading Order Detection")
if self.reading_order_machine_based:
self.logger.info("Using machine-based detection")
if self.right2left:
self.logger.info("Right-to-left mode enabled")
if self.headers_off:
self.logger.info("Headers ignored in reading order")
if self.reading_order_machine_based:
order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model(
contours_only_text_parent, contours_only_text_parent_h, text_regions_p)
else:
if np.abs(slope_deskew) < SLOPE_THRESHOLD:
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot)
else:
contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order(
contours_only_text_parent_d_ordered, index_by_text_par_con)
order_text_new, id_of_texts_tot = self.do_order_of_regions(
contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d)
if self.ocr and self.tr:
self.logger.info("Step 4.5/5: OCR Processing")
if torch.cuda.is_available():
self.logger.info("Using GPU acceleration")
else:
self.logger.info("Using CPU processing")
if self.light_version: if self.light_version:
self.logger.info("Using light version OCR") self.logger.info("Using light version OCR")
if self.textline_light: if self.textline_light:
self.logger.info("Using light text line detection for OCR") self.logger.info("Using light text line detection for OCR")
self.logger.info("Processing text lines...") self.logger.info("Processing text lines...")
device = cuda.get_current_device() self.device.reset()
device.reset()
gc.collect() gc.collect()
model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
torch.cuda.empty_cache() torch.cuda.empty_cache()
model_ocr.to(device) self.model_ocr.to(self.device)
ind_tot = 0 ind_tot = 0
#cv2.imwrite('./img_out.png', image_page) #cv2.imwrite('./img_out.png', image_page)
@ -5043,37 +4994,33 @@ class Eynollah:
img_croped = img_poly_on_img[y:y+h, x:x+w, :] img_croped = img_poly_on_img[y:y+h, x:x+w, :]
#cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped)
text_ocr = self.return_ocr_of_textline_without_common_section( text_ocr = self.return_ocr_of_textline_without_common_section(
img_croped, model_ocr, processor, device, w, h2w_ratio, ind_tot) img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot)
ocr_textline_in_textregion.append(text_ocr) ocr_textline_in_textregion.append(text_ocr)
ind_tot = ind_tot +1 ind_tot = ind_tot +1
ocr_all_textlines.append(ocr_textline_in_textregion) ocr_all_textlines.append(ocr_textline_in_textregion)
elif self.ocr and not self.tr:
gc.collect()
if len(all_found_textline_polygons)>0:
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if all_found_textline_polygons_marginals_left and len(all_found_textline_polygons_marginals_left)>0:
ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_left, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
if all_found_textline_polygons_marginals_right and len(all_found_textline_polygons_marginals_right)>0:
ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines(
image_page, all_found_textline_polygons_marginals_right, self.prediction_model,
self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else: else:
ocr_all_textlines = None ocr_all_textlines = None
ocr_all_textlines_marginals_left = None ocr_all_textlines_marginals_left = None
ocr_all_textlines_marginals_right = None ocr_all_textlines_marginals_right = None
self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") ocr_all_textlines_h = None
ocr_all_textlines_drop = None
self.logger.info("Step 5/5: Output Generation") self.logger.info("Step 5/5: Output Generation")
self.logger.info("Generating PAGE-XML output")
if self.full_layout:
pcgts = self.writer.build_pagexml_full_layout(
contours_only_text_parent, contours_only_text_parent_h, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h,
polygons_of_images, contours_tables, polygons_of_drop_capitals,
polygons_of_marginals_left, polygons_of_marginals_right,
all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right,
all_box_coord_marginals_left, all_box_coord_marginals_right,
slopes, slopes_h, slopes_marginals_left, slopes_marginals_right,
cont_page, polygons_seplines, ocr_all_textlines, ocr_all_textlines_h,
ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right,
ocr_all_textlines_drop,
conf_contours_textregions, conf_contours_textregions_h)
else:
pcgts = self.writer.build_pagexml_no_full_layout( pcgts = self.writer.build_pagexml_no_full_layout(
txt_con_org, page_coord, order_text_new, id_of_texts_tot, txt_con_org, page_coord, order_text_new, id_of_texts_tot,
all_found_textline_polygons, all_box_coord, polygons_of_images, all_found_textline_polygons, all_box_coord, polygons_of_images,