diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py
index c189aca..56d5d7e 100644
--- a/src/eynollah/cli.py
+++ b/src/eynollah/cli.py
@@ -374,6 +374,11 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     is_flag=True,
     help="If this parameter is set to True, the prediction will be performed using both RGB and binary images. However, this does not necessarily improve results; it may be beneficial for certain document images.",
 )
+@click.option(
+    "--batch_size",
+    "-bs",
+    help="Inference batch size. If not set, defaults to 2 for the TrOCR model and 8 for the CNN-RNN model.",
+)
 @click.option(
     "--log_level",
     "-l",
@@ -381,7 +386,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     help="Override log level globally to this",
 )
 
-def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, log_level):
+def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, batch_size, log_level):
     initLogging()
     if log_level:
         getLogger('eynollah').setLevel(getLevelName(log_level))
@@ -397,6 +402,7 @@ def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, ex
         do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
         draw_texts_on_image=draw_texts_on_image,
         prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
+        batch_size=batch_size,
     )
     eynollah_ocr.run()
diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py
index d148c67..62026bf 100644
--- a/src/eynollah/eynollah.py
+++ b/src/eynollah/eynollah.py
@@ -4872,6 +4872,7 @@ class Eynollah_ocr:
         dir_out=None,
         dir_out_image_text=None,
         tr_ocr=False,
+        batch_size=None,
         export_textline_images_and_text=False,
         do_not_mask_with_textline_contour=False,
         draw_texts_on_image=False,
@@ -4895,6 +4896,10 @@ class Eynollah_ocr:
             self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
             self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
             self.model_ocr.to(self.device)
+            if not batch_size:
+                self.b_s = 2
+            else:
+                self.b_s = int(batch_size)
 
         else:
             self.model_ocr_dir = dir_models + "/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
@@ -4903,6 +4908,10 @@ class Eynollah_ocr:
             self.prediction_model = tf.keras.models.Model(
                 model_ocr.get_layer(name = "image").input,
                 model_ocr.get_layer(name = "dense2").output)
+            if not batch_size:
+                self.b_s = 8
+            else:
+                self.b_s = int(batch_size)
 
             with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
@@ -4918,6 +4927,7 @@ class Eynollah_ocr:
             self.num_to_char = StringLookup(
                 vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
             )
+
 
     def decode_batch_predictions(self, pred, max_len = 128):
         # input_len is the product of the batch size and the
@@ -5073,10 +5083,9 @@ class Eynollah_ocr:
            ls_imgs = os.listdir(self.dir_in)
 
         if self.tr_ocr:
-            b_s = 2
+            tr_ocr_input_height_and_width = 384
             for ind_img in ls_imgs:
-                t0 = time.time()
-                file_name = ind_img.split('.')[0]
+                file_name = Path(ind_img).stem
                 dir_img = os.path.join(self.dir_in, ind_img)
                 dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
                 out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
@@ -5131,15 +5140,15 @@ class Eynollah_ocr:
                             img_crop[mask_poly==0] = 255
 
                             if h2w_ratio > 0.1:
-                                cropped_lines.append(img_crop)
+                                cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width, tr_ocr_input_height_and_width) )
                                 cropped_lines_meging_indexing.append(0)
                             else:
                                 splited_images, _ = self.return_textlines_split_if_needed(img_crop, None)
                                 #print(splited_images)
                                 if splited_images:
-                                    cropped_lines.append(splited_images[0])
+                                    cropped_lines.append(resize_image(splited_images[0], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
                                     cropped_lines_meging_indexing.append(1)
-                                    cropped_lines.append(splited_images[1])
+                                    cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
                                     cropped_lines_meging_indexing.append(-1)
                                 else:
                                     cropped_lines.append(img_crop)
@@ -5148,21 +5157,24 @@ class Eynollah_ocr:
 
                 extracted_texts = []
 
-                n_iterations = math.ceil(len(cropped_lines) / b_s)
+                n_iterations = math.ceil(len(cropped_lines) / self.b_s)
 
                 for i in range(n_iterations):
                     if i==(n_iterations-1):
-                        n_start = i*b_s
+                        n_start = i*self.b_s
                         imgs = cropped_lines[n_start:]
                     else:
-                        n_start = i*b_s
-                        n_end = (i+1)*b_s
+                        n_start = i*self.b_s
+                        n_end = (i+1)*self.b_s
                         imgs = cropped_lines[n_start:n_end]
                     pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
                     generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
                     generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
 
                     extracted_texts = extracted_texts + generated_text_merged
+
+                del cropped_lines
+                gc.collect()
 
                 extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
@@ -5241,14 +5253,12 @@ class Eynollah_ocr:
             padding_token = 299
             image_width = 512#max_len * 4
             image_height = 32
-            b_s = 8
 
             img_size=(image_width, image_height)
 
             for ind_img in ls_imgs:
-                t0 = time.time()
-                file_name = ind_img.split('.')[0]
+                file_name = Path(ind_img).stem
                 dir_img = os.path.join(self.dir_in, ind_img)
                 dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
                 out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
@@ -5368,11 +5378,11 @@ class Eynollah_ocr:
 
                 if not self.export_textline_images_and_text:
                     extracted_texts = []
-                    n_iterations = math.ceil(len(cropped_lines) / b_s)
+                    n_iterations = math.ceil(len(cropped_lines) / self.b_s)
 
                     for i in range(n_iterations):
                         if i==(n_iterations-1):
-                            n_start = i*b_s
+                            n_start = i*self.b_s
                             imgs = cropped_lines[n_start:]
                             imgs = np.array(imgs)
                             imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
@@ -5381,14 +5391,14 @@
                                 imgs_bin = np.array(imgs_bin)
                                 imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
                         else:
-                            n_start = i*b_s
-                            n_end = (i+1)*b_s
+                            n_start = i*self.b_s
+                            n_end = (i+1)*self.b_s
                             imgs = cropped_lines[n_start:n_end]
-                            imgs = np.array(imgs).reshape(b_s, image_height, image_width, 3)
+                            imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
 
                             if self.prediction_with_both_of_rgb_and_bin:
                                 imgs_bin = cropped_lines_bin[n_start:n_end]
-                                imgs_bin = np.array(imgs_bin).reshape(b_s, image_height, image_width, 3)
+                                imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
 
                         preds = self.prediction_model.predict(imgs, verbose=0)
@@ -5402,6 +5412,11 @@ class Eynollah_ocr:
                             pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
                             extracted_texts.append(pred_texts_ib)
 
+                    del cropped_lines
+                    if self.prediction_with_both_of_rgb_and_bin:
+                        del cropped_lines_bin
+                    gc.collect()
+
                     extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
 
                     extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
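Taken together, the patch resolves --batch_size/-bs once in Eynollah_ocr.__init__ (falling back to 2 for TrOCR and 8 for the CNN-RNN model) and then drives both inference loops from self.b_s. Below is a minimal, self-contained sketch of that batching pattern, not part of the patch itself: resolve_batch_size and iter_batches are hypothetical helper names, but the defaults and the math.ceil/slice logic mirror the hunks above.

import math

def resolve_batch_size(batch_size, tr_ocr):
    # Mirrors the patched __init__: with no CLI value (None or empty string),
    # fall back to 2 for the TrOCR model and 8 for the CNN-RNN model;
    # otherwise cast the click-supplied value (a string) to int.
    if not batch_size:
        return 2 if tr_ocr else 8
    return int(batch_size)

def iter_batches(items, b_s):
    # Mirrors the patched loops: math.ceil(len(items) / b_s) iterations,
    # with the last slice left open-ended so a short final batch is kept.
    n_iterations = math.ceil(len(items) / b_s)
    for i in range(n_iterations):
        if i == (n_iterations - 1):
            yield items[i * b_s:]
        else:
            yield items[i * b_s:(i + 1) * b_s]

if __name__ == "__main__":
    b_s = resolve_batch_size(None, tr_ocr=True)                    # default TrOCR batch size -> 2
    sizes = [len(batch) for batch in iter_batches(list(range(7)), b_s)]
    print(b_s, sizes)                                              # 2 [2, 2, 2, 1]

Note that the CNN-RNN branch reshapes full batches to exactly (self.b_s, image_height, image_width, 3), so only the open-ended final slice may be smaller than the configured batch size.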