threshold for textline ocr + new ocr model

This commit is contained in:
vahidrezanezhad 2025-07-25 13:18:38 +02:00
parent d968a306e4
commit 0803881f36
2 changed files with 76 additions and 49 deletions

View file

@ -496,6 +496,11 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
"-ds_pref", "-ds_pref",
help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset", help="in the case of extracting textline and text from a xml GT file user can add an abbrevation of dataset name to generated dataset",
) )
@click.option(
"--min_conf_value_of_textline_text",
"-min_conf",
help="minimum OCR confidence value. Text lines with a confidence value lower than this threshold will not be included in the output XML file.",
)
@click.option( @click.option(
"--log_level", "--log_level",
"-l", "-l",
@ -503,7 +508,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="Override log level globally to this", help="Override log level globally to this",
) )
def ocr(image, overwrite, dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, batch_size, dataset_abbrevation, log_level): def ocr(image, overwrite, dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, batch_size, dataset_abbrevation, min_conf_value_of_textline_text, log_level):
initLogging() initLogging()
if log_level: if log_level:
getLogger('eynollah').setLevel(getLevelName(log_level)) getLogger('eynollah').setLevel(getLevelName(log_level))
@ -530,6 +535,7 @@ def ocr(image, overwrite, dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text,
prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin, prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
batch_size=batch_size, batch_size=batch_size,
pref_of_dataset=dataset_abbrevation, pref_of_dataset=dataset_abbrevation,
min_conf_value_of_textline_text=min_conf_value_of_textline_text,
) )
eynollah_ocr.run(overwrite=overwrite) eynollah_ocr.run(overwrite=overwrite)

View file

@ -318,7 +318,7 @@ class Eynollah:
if self.ocr and self.tr: if self.ocr and self.tr:
self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124" self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
elif self.ocr and not self.tr: elif self.ocr and not self.tr:
self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250716" self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250725"
if self.tables: if self.tables:
if self.light_version: if self.light_version:
self.model_table_dir = dir_models + "/modelens_table_0t4_201124" self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
@ -4974,13 +4974,23 @@ class Eynollah:
gc.collect() gc.collect()
if len(all_found_textline_polygons)>0: if len(all_found_textline_polygons)>0:
ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) ocr_all_textlines = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines = None
if all_found_textline_polygons_marginals and len(all_found_textline_polygons_marginals)>0: if all_found_textline_polygons_marginals and len(all_found_textline_polygons_marginals)>0:
ocr_all_textlines_marginals = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons_marginals, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) ocr_all_textlines_marginals = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons_marginals, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines_marginals = None
if all_found_textline_polygons_h and len(all_found_textline_polygons)>0: if all_found_textline_polygons_h and len(all_found_textline_polygons)>0:
ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons_h, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) ocr_all_textlines_h = return_rnn_cnn_ocr_of_given_textlines(image_page, all_found_textline_polygons_h, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines_h = None
if polygons_of_drop_capitals and len(polygons_of_drop_capitals)>0: if polygons_of_drop_capitals and len(polygons_of_drop_capitals)>0:
ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(image_page, polygons_of_drop_capitals, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line) ocr_all_textlines_drop = return_rnn_cnn_ocr_of_given_textlines(image_page, polygons_of_drop_capitals, self.prediction_model, self.b_s_ocr, self.num_to_char, self.textline_light, self.curved_line)
else:
ocr_all_textlines_drop = None
else: else:
ocr_all_textlines = None ocr_all_textlines = None
ocr_all_textlines_marginals = None ocr_all_textlines_marginals = None
@ -5098,7 +5108,8 @@ class Eynollah_ocr:
do_not_mask_with_textline_contour=False, do_not_mask_with_textline_contour=False,
draw_texts_on_image=False, draw_texts_on_image=False,
prediction_with_both_of_rgb_and_bin=False, prediction_with_both_of_rgb_and_bin=False,
pref_of_dataset = None, pref_of_dataset=None,
min_conf_value_of_textline_text : Optional[float]=None,
logger=None, logger=None,
): ):
self.dir_in = dir_in self.dir_in = dir_in
@ -5117,6 +5128,10 @@ class Eynollah_ocr:
self.logger = logger if logger else getLogger('eynollah') self.logger = logger if logger else getLogger('eynollah')
if not export_textline_images_and_text: if not export_textline_images_and_text:
if min_conf_value_of_textline_text:
self.min_conf_value_of_textline_text = float(min_conf_value_of_textline_text)
else:
self.min_conf_value_of_textline_text = 0.3
if tr_ocr: if tr_ocr:
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -5129,7 +5144,7 @@ class Eynollah_ocr:
self.b_s = int(batch_size) self.b_s = int(batch_size)
else: else:
self.model_ocr_dir = dir_models + "/model_ens_ocrcnn_new6"#"/model_eynollah_ocr_cnnrnn_20250716"#"/model_ens_ocrcnn_new6"#"/model_ens_ocrcnn_new2"# self.model_ocr_dir = dir_models + "/model_eynollah_ocr_cnnrnn_20250725"#"/model_step_1020000_ocr"#"/model_ens_ocrcnn_new10"#"/model_step_255000_ocr"#"/model_ens_ocrcnn_new9"#"/model_step_900000_ocr"#"/model_eynollah_ocr_cnnrnn_20250716"#"/model_ens_ocrcnn_new6"#"/model_ens_ocrcnn_new2"#
model_ocr = load_model(self.model_ocr_dir , compile=False) model_ocr = load_model(self.model_ocr_dir , compile=False)
self.prediction_model = tf.keras.models.Model( self.prediction_model = tf.keras.models.Model(
@ -5140,8 +5155,7 @@ class Eynollah_ocr:
else: else:
self.b_s = int(batch_size) self.b_s = int(batch_size)
with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
with open(os.path.join(self.model_ocr_dir, "characters_20250707_all_lang.txt"),"r") as config_file:
characters = json.load(config_file) characters = json.load(config_file)
AUTOTUNE = tf.data.AUTOTUNE AUTOTUNE = tf.data.AUTOTUNE
@ -5442,51 +5456,55 @@ class Eynollah_ocr:
else: else:
#print(file_name, angle_degrees,w*h , mask_poly[:,:,0].sum(), mask_poly[:,:,0].sum() /float(w*h) , 'didi') #print(file_name, angle_degrees,w*h , mask_poly[:,:,0].sum(), mask_poly[:,:,0].sum() /float(w*h) , 'didi')
if not self.do_not_mask_with_textline_contour:
if angle_degrees > 3:
better_des_slope = get_orientation_moments(textline_coords)
img_crop = rotate_image_with_padding(img_crop, better_des_slope ) if angle_degrees > 3:
better_des_slope = get_orientation_moments(textline_coords)
if self.prediction_with_both_of_rgb_and_bin: img_crop = rotate_image_with_padding(img_crop, better_des_slope )
img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope )
mask_poly = rotate_image_with_padding(mask_poly, better_des_slope ) if self.prediction_with_both_of_rgb_and_bin:
mask_poly = mask_poly.astype('uint8') img_crop_bin = rotate_image_with_padding(img_crop_bin, better_des_slope )
#new bounding box mask_poly = rotate_image_with_padding(mask_poly, better_des_slope )
x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0]) mask_poly = mask_poly.astype('uint8')
mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :] #new bounding box
img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :] x_n, y_n, w_n, h_n = get_contours_and_bounding_boxes(mask_poly[:,:,0])
mask_poly = mask_poly[y_n:y_n+h_n, x_n:x_n+w_n, :]
img_crop = img_crop[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255 img_crop[mask_poly==0] = 255
if self.prediction_with_both_of_rgb_and_bin: if self.prediction_with_both_of_rgb_and_bin:
img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :] img_crop_bin = img_crop_bin[y_n:y_n+h_n, x_n:x_n+w_n, :]
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255 img_crop_bin[mask_poly==0] = 255
if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90: if mask_poly[:,:,0].sum() /float(w_n*h_n) < 0.50 and w_scaled > 90:
if self.prediction_with_both_of_rgb_and_bin:
img_crop, img_crop_bin = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly)
else:
better_des_slope = 0
if not self.do_not_mask_with_textline_contour:
img_crop[mask_poly==0] = 255
if self.prediction_with_both_of_rgb_and_bin:
if not self.do_not_mask_with_textline_contour:
img_crop_bin[mask_poly==0] = 255
if type_textregion=='drop-capital':
pass
else:
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
if self.prediction_with_both_of_rgb_and_bin: if self.prediction_with_both_of_rgb_and_bin:
img_crop, img_crop_bin = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly, img_crop_bin) img_crop, img_crop_bin = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly, img_crop_bin)
else: else:
img_crop, _ = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly) img_crop, _ = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly)
else:
better_des_slope = 0
img_crop[mask_poly==0] = 255
if self.prediction_with_both_of_rgb_and_bin:
img_crop_bin[mask_poly==0] = 255
if type_textregion=='drop-capital':
pass
else:
if mask_poly[:,:,0].sum() /float(w*h) < 0.50 and w_scaled > 90:
if self.prediction_with_both_of_rgb_and_bin:
img_crop, img_crop_bin = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly, img_crop_bin)
else:
img_crop, _ = break_curved_line_into_small_pieces_and_then_merge(img_crop, mask_poly)
if not self.export_textline_images_and_text: if not self.export_textline_images_and_text:
if w_scaled < 750:#1.5*image_width: if w_scaled < 750:#1.5*image_width:
img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width) img_fin = preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
@ -5716,9 +5734,12 @@ class Eynollah_ocr:
for ib in range(imgs.shape[0]): for ib in range(imgs.shape[0]):
pred_texts_ib = pred_texts[ib].replace("[UNK]", "") pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
extracted_texts.append(pred_texts_ib) if masked_means[ib] >= self.min_conf_value_of_textline_text:
extracted_conf_value.append(masked_means[ib]) extracted_texts.append(pred_texts_ib)
extracted_conf_value.append(masked_means[ib])
else:
extracted_texts.append("")
extracted_conf_value.append(0)
del cropped_lines del cropped_lines
if self.prediction_with_both_of_rgb_and_bin: if self.prediction_with_both_of_rgb_and_bin:
del cropped_lines_bin del cropped_lines_bin
@ -5790,14 +5811,14 @@ class Eynollah_ocr:
###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)} ###id_to_order = {tid: ro for tid, ro in zip(tot_region_ref, index_tot_regions)}
id_textregions = [] #id_textregions = []
textregions_by_existing_ids = [] #textregions_by_existing_ids = []
indexer = 0 indexer = 0
indexer_textregion = 0 indexer_textregion = 0
for nn in root1.iter(region_tags): for nn in root1.iter(region_tags):
id_textregion = nn.attrib['id'] #id_textregion = nn.attrib['id']
id_textregions.append(id_textregion) #id_textregions.append(id_textregion)
textregions_by_existing_ids.append(text_by_textregion[indexer_textregion]) #textregions_by_existing_ids.append(text_by_textregion[indexer_textregion])
is_textregion_text = False is_textregion_text = False
for childtest in nn: for childtest in nn: