From 595ed02743afc3ab8359de5f6feb0ca680546599 Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Mon, 6 Oct 2025 17:24:50 +0200 Subject: [PATCH] run_single: simplify; allow running TrOCR in non-fl mode, too - refactor final `self.full_layout` conditional, removing copied code - allow running `self.ocr` and `self.tr` branch in both cases (non/fl) - when running TrOCR, use model / processor / device initialised during init (instead of ad-hoc loading) --- src/eynollah/eynollah.py | 277 ++++++++++++++++----------------------- 1 file changed, 112 insertions(+), 165 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 834ecf3..079cf8c 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -379,9 +379,14 @@ class Eynollah: self.model_reading_order = self.our_load_model(self.model_reading_order_dir) if self.ocr and self.tr: self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - #("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten") - self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") + if torch.cuda.is_available(): + self.logger.info("Using GPU acceleration") + self.device = torch.device("cuda:0") + else: + self.logger.info("Using CPU processing") + self.device = torch.device("cpu") + #self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") + self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") elif self.ocr and not self.tr: model_ocr = load_model(self.model_ocr_dir , compile=False) @@ -4805,12 +4810,13 @@ class Eynollah: slopes_marginals, mid_point_of_page_width) #print(len(polygons_of_marginals), len(ordered_left_marginals), len(ordered_right_marginals), 'marginals ordred') + if np.abs(slope_deskew) >= SLOPE_THRESHOLD: + contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order( + contours_only_text_parent_d_ordered, index_by_text_par_con) + else: + contours_only_text_parent_d_ordered = None + if self.full_layout: - if np.abs(slope_deskew) >= SLOPE_THRESHOLD: - contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order( - contours_only_text_parent_d_ordered, index_by_text_par_con) - else: - contours_only_text_parent_d_ordered = None if self.light_version: fun = check_any_text_region_in_model_one_is_main_or_header_light else: @@ -4869,44 +4875,43 @@ class Eynollah: splitter_y_new_d, regions_without_separators_d, matrix_of_lines_ch_d, num_col_classifier, erosion_hurts, self.tables, self.right2left, logger=self.logger) + else: + contours_only_text_parent_h = [] + contours_only_text_parent_h_d_ordered = [] if self.plotter: self.plotter.write_images_into_directory(polygons_of_images, image_page) t_order = time.time() - if self.full_layout: - self.logger.info("Step 4/5: Reading Order Detection") - - if self.reading_order_machine_based: - self.logger.info("Using machine-based detection") - if self.right2left: - self.logger.info("Right-to-left mode enabled") - if self.headers_off: - self.logger.info("Headers ignored in reading order") + #if self.full_layout: + self.logger.info("Step 4/5: Reading Order Detection") - if self.reading_order_machine_based: - tror = time.time() - order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model( - contours_only_text_parent, contours_only_text_parent_h, text_regions_p) + if self.reading_order_machine_based: + self.logger.info("Using machine-based detection") + if self.right2left: + self.logger.info("Right-to-left mode enabled") + if self.headers_off: + self.logger.info("Headers ignored in reading order") + + if self.reading_order_machine_based: + order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model( + contours_only_text_parent, contours_only_text_parent_h, text_regions_p) + else: + if np.abs(slope_deskew) < SLOPE_THRESHOLD: + order_text_new, id_of_texts_tot = self.do_order_of_regions( + contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot) else: - if np.abs(slope_deskew) < SLOPE_THRESHOLD: - order_text_new, id_of_texts_tot = self.do_order_of_regions( - contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot) - else: - order_text_new, id_of_texts_tot = self.do_order_of_regions( - contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered, - boxes_d, textline_mask_tot_d) - self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") + order_text_new, id_of_texts_tot = self.do_order_of_regions( + contours_only_text_parent_d_ordered, contours_only_text_parent_h_d_ordered, + boxes_d, textline_mask_tot_d) + self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") - if self.ocr and not self.tr: - self.logger.info("Step 4.5/5: OCR Processing") - - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - else: - self.logger.info("Using CPU processing") - + if self.ocr: + self.logger.info("Step 4.5/5: OCR Processing") + + if not self.tr: gc.collect() + if len(all_found_textline_polygons)>0: ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( image_page, all_found_textline_polygons, self.prediction_model, @@ -4941,15 +4946,68 @@ class Eynollah: self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) else: ocr_all_textlines_drop = None + else: - ocr_all_textlines = None - ocr_all_textlines_marginals_left = None - ocr_all_textlines_marginals_right = None - ocr_all_textlines_h = None - ocr_all_textlines_drop = None + if self.light_version: + self.logger.info("Using light version OCR") + if self.textline_light: + self.logger.info("Using light text line detection for OCR") + self.logger.info("Processing text lines...") + + self.device.reset() + gc.collect() + + torch.cuda.empty_cache() + self.model_ocr.to(self.device) + + ind_tot = 0 + #cv2.imwrite('./img_out.png', image_page) + ocr_all_textlines = [] + for indexing, ind_poly_first in enumerate(all_found_textline_polygons): + ocr_textline_in_textregion = [] + for indexing2, ind_poly in enumerate(ind_poly_first): + if not (self.textline_light or self.curved_line): + ind_poly = copy.deepcopy(ind_poly) + box_ind = all_box_coord[indexing] + #print(ind_poly,np.shape(ind_poly), 'ind_poly') + #print(box_ind) + ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind) + #print(ind_poly_copy) + ind_poly[ind_poly<0] = 0 + x, y, w, h = cv2.boundingRect(ind_poly) + #print(ind_poly_copy, np.shape(ind_poly_copy)) + #print(x, y, w, h, h/float(w),'ratio') + h2w_ratio = h/float(w) + mask_poly = np.zeros(image_page.shape) + if not self.light_version: + img_poly_on_img = np.copy(image_page) + else: + img_poly_on_img = np.copy(img_bin_light) + mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1)) + + if self.textline_light: + mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1) + img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255 + img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255 + img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255 + + img_croped = img_poly_on_img[y:y+h, x:x+w, :] + #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) + text_ocr = self.return_ocr_of_textline_without_common_section( + img_croped, self.model_ocr, self.processor, self.device, w, h2w_ratio, ind_tot) + ocr_textline_in_textregion.append(text_ocr) + ind_tot = ind_tot +1 + ocr_all_textlines.append(ocr_textline_in_textregion) + else: + ocr_all_textlines = None + ocr_all_textlines_marginals_left = None + ocr_all_textlines_marginals_right = None + ocr_all_textlines_h = None + ocr_all_textlines_drop = None - self.logger.info("Step 5/5: Output Generation") - + self.logger.info("Step 5/5: Output Generation") + + if self.full_layout: pcgts = self.writer.build_pagexml_full_layout( contours_only_text_parent, contours_only_text_parent_h, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_found_textline_polygons_h, all_box_coord, all_box_coord_h, @@ -4962,129 +5020,18 @@ class Eynollah: ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right, ocr_all_textlines_drop, conf_contours_textregions, conf_contours_textregions_h) - - return pcgts - - contours_only_text_parent_h = [] - self.logger.info("Step 4/5: Reading Order Detection") - - if self.reading_order_machine_based: - self.logger.info("Using machine-based detection") - if self.right2left: - self.logger.info("Right-to-left mode enabled") - if self.headers_off: - self.logger.info("Headers ignored in reading order") - - if self.reading_order_machine_based: - order_text_new, id_of_texts_tot = self.do_order_of_regions_with_model( - contours_only_text_parent, contours_only_text_parent_h, text_regions_p) else: - if np.abs(slope_deskew) < SLOPE_THRESHOLD: - order_text_new, id_of_texts_tot = self.do_order_of_regions( - contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot) - else: - contours_only_text_parent_d_ordered = self.return_list_of_contours_with_desired_order( - contours_only_text_parent_d_ordered, index_by_text_par_con) - order_text_new, id_of_texts_tot = self.do_order_of_regions( - contours_only_text_parent_d_ordered, contours_only_text_parent_h, boxes_d, textline_mask_tot_d) - - if self.ocr and self.tr: - self.logger.info("Step 4.5/5: OCR Processing") - if torch.cuda.is_available(): - self.logger.info("Using GPU acceleration") - else: - self.logger.info("Using CPU processing") - if self.light_version: - self.logger.info("Using light version OCR") - if self.textline_light: - self.logger.info("Using light text line detection for OCR") - self.logger.info("Processing text lines...") + pcgts = self.writer.build_pagexml_no_full_layout( + txt_con_org, page_coord, order_text_new, id_of_texts_tot, + all_found_textline_polygons, all_box_coord, polygons_of_images, + polygons_of_marginals_left, polygons_of_marginals_right, + all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, + all_box_coord_marginals_left, all_box_coord_marginals_right, + slopes, slopes_marginals_left, slopes_marginals_right, + cont_page, polygons_seplines, contours_tables, ocr_all_textlines, + ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right, + conf_contours_textregions) - device = cuda.get_current_device() - device.reset() - gc.collect() - model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") - torch.cuda.empty_cache() - model_ocr.to(device) - - ind_tot = 0 - #cv2.imwrite('./img_out.png', image_page) - ocr_all_textlines = [] - for indexing, ind_poly_first in enumerate(all_found_textline_polygons): - ocr_textline_in_textregion = [] - for indexing2, ind_poly in enumerate(ind_poly_first): - if not (self.textline_light or self.curved_line): - ind_poly = copy.deepcopy(ind_poly) - box_ind = all_box_coord[indexing] - #print(ind_poly,np.shape(ind_poly), 'ind_poly') - #print(box_ind) - ind_poly = return_textline_contour_with_added_box_coordinate(ind_poly, box_ind) - #print(ind_poly_copy) - ind_poly[ind_poly<0] = 0 - x, y, w, h = cv2.boundingRect(ind_poly) - #print(ind_poly_copy, np.shape(ind_poly_copy)) - #print(x, y, w, h, h/float(w),'ratio') - h2w_ratio = h/float(w) - mask_poly = np.zeros(image_page.shape) - if not self.light_version: - img_poly_on_img = np.copy(image_page) - else: - img_poly_on_img = np.copy(img_bin_light) - mask_poly = cv2.fillPoly(mask_poly, pts=[ind_poly], color=(1, 1, 1)) - - if self.textline_light: - mask_poly = cv2.dilate(mask_poly, KERNEL, iterations=1) - img_poly_on_img[:,:,0][mask_poly[:,:,0] ==0] = 255 - img_poly_on_img[:,:,1][mask_poly[:,:,0] ==0] = 255 - img_poly_on_img[:,:,2][mask_poly[:,:,0] ==0] = 255 - - img_croped = img_poly_on_img[y:y+h, x:x+w, :] - #cv2.imwrite('./extracted_lines/'+str(ind_tot)+'.jpg', img_croped) - text_ocr = self.return_ocr_of_textline_without_common_section( - img_croped, model_ocr, processor, device, w, h2w_ratio, ind_tot) - ocr_textline_in_textregion.append(text_ocr) - ind_tot = ind_tot +1 - ocr_all_textlines.append(ocr_textline_in_textregion) - - elif self.ocr and not self.tr: - gc.collect() - if len(all_found_textline_polygons)>0: - ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons, self.prediction_model, - self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if all_found_textline_polygons_marginals_left and len(all_found_textline_polygons_marginals_left)>0: - ocr_all_textlines_marginals_left = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons_marginals_left, self.prediction_model, - self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - if all_found_textline_polygons_marginals_right and len(all_found_textline_polygons_marginals_right)>0: - ocr_all_textlines_marginals_right = return_rnn_cnn_ocr_of_given_textlines( - image_page, all_found_textline_polygons_marginals_right, self.prediction_model, - self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) - - else: - ocr_all_textlines = None - ocr_all_textlines_marginals_left = None - ocr_all_textlines_marginals_right = None - self.logger.info(f"Detection of reading order took {time.time() - t_order:.1f}s") - - self.logger.info("Step 5/5: Output Generation") - self.logger.info("Generating PAGE-XML output") - - pcgts = self.writer.build_pagexml_no_full_layout( - txt_con_org, page_coord, order_text_new, id_of_texts_tot, - all_found_textline_polygons, all_box_coord, polygons_of_images, - polygons_of_marginals_left, polygons_of_marginals_right, - all_found_textline_polygons_marginals_left, all_found_textline_polygons_marginals_right, - all_box_coord_marginals_left, all_box_coord_marginals_right, - slopes, slopes_marginals_left, slopes_marginals_right, - cont_page, polygons_seplines, contours_tables, ocr_all_textlines, - ocr_all_textlines_marginals_left, ocr_all_textlines_marginals_right, - conf_contours_textregions) - return pcgts