The inference batch size for OCR can now be passed as a command-line argument (`--batch_size` / `-bs`)

This commit is contained in:
vahidrezanezhad 2025-05-02 12:53:33 +02:00
parent fd375e15d5
commit a4defbb04d
2 changed files with 41 additions and 20 deletions

View file

@ -374,6 +374,11 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
is_flag=True,
help="If this parameter is set to True, the prediction will be performed using both RGB and binary images. However, this does not necessarily improve results; it may be beneficial for certain document images.",
)
@click.option(
"--batch_size",
"-bs",
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
)
@click.option(
"--log_level",
"-l",
@ -381,7 +386,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
help="Override log level globally to this",
)
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, log_level):
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, batch_size, log_level):
initLogging()
if log_level:
getLogger('eynollah').setLevel(getLevelName(log_level))
@ -397,6 +402,7 @@ def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, ex
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
draw_texts_on_image=draw_texts_on_image,
prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
batch_size=batch_size,
)
eynollah_ocr.run()

View file

@ -4872,6 +4872,7 @@ class Eynollah_ocr:
dir_out=None,
dir_out_image_text=None,
tr_ocr=False,
batch_size=None,
export_textline_images_and_text=False,
do_not_mask_with_textline_contour=False,
draw_texts_on_image=False,
@ -4895,6 +4896,10 @@ class Eynollah_ocr:
self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.model_ocr.to(self.device)
if not batch_size:
self.b_s = 2
else:
self.b_s = int(batch_size)
else:
self.model_ocr_dir = dir_models + "/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
@ -4903,6 +4908,10 @@ class Eynollah_ocr:
self.prediction_model = tf.keras.models.Model(
model_ocr.get_layer(name = "image").input,
model_ocr.get_layer(name = "dense2").output)
if not batch_size:
self.b_s = 8
else:
self.b_s = int(batch_size)
with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
@ -4918,6 +4927,7 @@ class Eynollah_ocr:
self.num_to_char = StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
def decode_batch_predictions(self, pred, max_len = 128):
# input_len is the product of the batch size and the
@ -5073,10 +5083,9 @@ class Eynollah_ocr:
ls_imgs = os.listdir(self.dir_in)
if self.tr_ocr:
b_s = 2
tr_ocr_input_height_and_width = 384
for ind_img in ls_imgs:
t0 = time.time()
file_name = ind_img.split('.')[0]
file_name = Path(ind_img).stem
dir_img = os.path.join(self.dir_in, ind_img)
dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
@ -5131,15 +5140,15 @@ class Eynollah_ocr:
img_crop[mask_poly==0] = 255
if h2w_ratio > 0.1:
cropped_lines.append(img_crop)
cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width, tr_ocr_input_height_and_width) )
cropped_lines_meging_indexing.append(0)
else:
splited_images, _ = self.return_textlines_split_if_needed(img_crop, None)
#print(splited_images)
if splited_images:
cropped_lines.append(splited_images[0])
cropped_lines.append(resize_image(splited_images[0], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(1)
cropped_lines.append(splited_images[1])
cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
cropped_lines_meging_indexing.append(-1)
else:
cropped_lines.append(img_crop)
@ -5148,21 +5157,24 @@ class Eynollah_ocr:
extracted_texts = []
n_iterations = math.ceil(len(cropped_lines) / b_s)
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
for i in range(n_iterations):
if i==(n_iterations-1):
n_start = i*b_s
n_start = i*self.b_s
imgs = cropped_lines[n_start:]
else:
n_start = i*b_s
n_end = (i+1)*b_s
n_start = i*self.b_s
n_end = (i+1)*self.b_s
imgs = cropped_lines[n_start:n_end]
pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
extracted_texts = extracted_texts + generated_text_merged
del cropped_lines
gc.collect()
extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
@ -5241,14 +5253,12 @@ class Eynollah_ocr:
padding_token = 299
image_width = 512#max_len * 4
image_height = 32
b_s = 8
img_size=(image_width, image_height)
for ind_img in ls_imgs:
t0 = time.time()
file_name = ind_img.split('.')[0]
file_name = Path(ind_img).stem
dir_img = os.path.join(self.dir_in, ind_img)
dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
@ -5368,11 +5378,11 @@ class Eynollah_ocr:
if not self.export_textline_images_and_text:
extracted_texts = []
n_iterations = math.ceil(len(cropped_lines) / b_s)
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
for i in range(n_iterations):
if i==(n_iterations-1):
n_start = i*b_s
n_start = i*self.b_s
imgs = cropped_lines[n_start:]
imgs = np.array(imgs)
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
@ -5381,14 +5391,14 @@ class Eynollah_ocr:
imgs_bin = np.array(imgs_bin)
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
else:
n_start = i*b_s
n_end = (i+1)*b_s
n_start = i*self.b_s
n_end = (i+1)*self.b_s
imgs = cropped_lines[n_start:n_end]
imgs = np.array(imgs).reshape(b_s, image_height, image_width, 3)
imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
if self.prediction_with_both_of_rgb_and_bin:
imgs_bin = cropped_lines_bin[n_start:n_end]
imgs_bin = np.array(imgs_bin).reshape(b_s, image_height, image_width, 3)
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
preds = self.prediction_model.predict(imgs, verbose=0)
@ -5402,6 +5412,11 @@ class Eynollah_ocr:
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
extracted_texts.append(pred_texts_ib)
del cropped_lines
if self.prediction_with_both_of_rgb_and_bin:
del cropped_lines_bin
gc.collect()
extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]