mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-08-29 03:49:54 +02:00
inference batch size for ocr is passed as an argument
This commit is contained in:
parent
fd375e15d5
commit
a4defbb04d
2 changed files with 41 additions and 20 deletions
|
@ -374,6 +374,11 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
|
|||
is_flag=True,
|
||||
help="If this parameter is set to True, the prediction will be performed using both RGB and binary images. However, this does not necessarily improve results; it may be beneficial for certain document images.",
|
||||
)
|
||||
@click.option(
|
||||
"--batch_size",
|
||||
"-bs",
|
||||
help="number of inference batch size. Default b_s for trocr and cnn_rnn models are 2 and 8 respectively",
|
||||
)
|
||||
@click.option(
|
||||
"--log_level",
|
||||
"-l",
|
||||
|
@ -381,7 +386,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
|
|||
help="Override log level globally to this",
|
||||
)
|
||||
|
||||
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, log_level):
|
||||
def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, draw_texts_on_image, prediction_with_both_of_rgb_and_bin, batch_size, log_level):
|
||||
initLogging()
|
||||
if log_level:
|
||||
getLogger('eynollah').setLevel(getLevelName(log_level))
|
||||
|
@ -397,6 +402,7 @@ def ocr(dir_in, dir_in_bin, out, dir_xmls, dir_out_image_text, model, tr_ocr, ex
|
|||
do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
|
||||
draw_texts_on_image=draw_texts_on_image,
|
||||
prediction_with_both_of_rgb_and_bin=prediction_with_both_of_rgb_and_bin,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
eynollah_ocr.run()
|
||||
|
||||
|
|
|
@ -4872,6 +4872,7 @@ class Eynollah_ocr:
|
|||
dir_out=None,
|
||||
dir_out_image_text=None,
|
||||
tr_ocr=False,
|
||||
batch_size=None,
|
||||
export_textline_images_and_text=False,
|
||||
do_not_mask_with_textline_contour=False,
|
||||
draw_texts_on_image=False,
|
||||
|
@ -4895,6 +4896,10 @@ class Eynollah_ocr:
|
|||
self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
|
||||
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
|
||||
self.model_ocr.to(self.device)
|
||||
if not batch_size:
|
||||
self.b_s = 2
|
||||
else:
|
||||
self.b_s = int(batch_size)
|
||||
|
||||
else:
|
||||
self.model_ocr_dir = dir_models + "/model_step_1050000_ocr"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
|
||||
|
@ -4903,6 +4908,10 @@ class Eynollah_ocr:
|
|||
self.prediction_model = tf.keras.models.Model(
|
||||
model_ocr.get_layer(name = "image").input,
|
||||
model_ocr.get_layer(name = "dense2").output)
|
||||
if not batch_size:
|
||||
self.b_s = 8
|
||||
else:
|
||||
self.b_s = int(batch_size)
|
||||
|
||||
|
||||
with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file:
|
||||
|
@ -4918,6 +4927,7 @@ class Eynollah_ocr:
|
|||
self.num_to_char = StringLookup(
|
||||
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
|
||||
)
|
||||
|
||||
|
||||
def decode_batch_predictions(self, pred, max_len = 128):
|
||||
# input_len is the product of the batch size and the
|
||||
|
@ -5073,10 +5083,9 @@ class Eynollah_ocr:
|
|||
ls_imgs = os.listdir(self.dir_in)
|
||||
|
||||
if self.tr_ocr:
|
||||
b_s = 2
|
||||
tr_ocr_input_height_and_width = 384
|
||||
for ind_img in ls_imgs:
|
||||
t0 = time.time()
|
||||
file_name = ind_img.split('.')[0]
|
||||
file_name = Path(ind_img).stem
|
||||
dir_img = os.path.join(self.dir_in, ind_img)
|
||||
dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
|
||||
out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
|
||||
|
@ -5131,15 +5140,15 @@ class Eynollah_ocr:
|
|||
img_crop[mask_poly==0] = 255
|
||||
|
||||
if h2w_ratio > 0.1:
|
||||
cropped_lines.append(img_crop)
|
||||
cropped_lines.append(resize_image(img_crop, tr_ocr_input_height_and_width, tr_ocr_input_height_and_width) )
|
||||
cropped_lines_meging_indexing.append(0)
|
||||
else:
|
||||
splited_images, _ = self.return_textlines_split_if_needed(img_crop, None)
|
||||
#print(splited_images)
|
||||
if splited_images:
|
||||
cropped_lines.append(splited_images[0])
|
||||
cropped_lines.append(resize_image(splited_images[0], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
|
||||
cropped_lines_meging_indexing.append(1)
|
||||
cropped_lines.append(splited_images[1])
|
||||
cropped_lines.append(resize_image(splited_images[1], tr_ocr_input_height_and_width, tr_ocr_input_height_and_width))
|
||||
cropped_lines_meging_indexing.append(-1)
|
||||
else:
|
||||
cropped_lines.append(img_crop)
|
||||
|
@ -5148,21 +5157,24 @@ class Eynollah_ocr:
|
|||
|
||||
|
||||
extracted_texts = []
|
||||
n_iterations = math.ceil(len(cropped_lines) / b_s)
|
||||
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
||||
|
||||
for i in range(n_iterations):
|
||||
if i==(n_iterations-1):
|
||||
n_start = i*b_s
|
||||
n_start = i*self.b_s
|
||||
imgs = cropped_lines[n_start:]
|
||||
else:
|
||||
n_start = i*b_s
|
||||
n_end = (i+1)*b_s
|
||||
n_start = i*self.b_s
|
||||
n_end = (i+1)*self.b_s
|
||||
imgs = cropped_lines[n_start:n_end]
|
||||
pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values
|
||||
generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device))
|
||||
generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True)
|
||||
|
||||
extracted_texts = extracted_texts + generated_text_merged
|
||||
|
||||
del cropped_lines
|
||||
gc.collect()
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
|
@ -5241,14 +5253,12 @@ class Eynollah_ocr:
|
|||
padding_token = 299
|
||||
image_width = 512#max_len * 4
|
||||
image_height = 32
|
||||
b_s = 8
|
||||
|
||||
|
||||
img_size=(image_width, image_height)
|
||||
|
||||
for ind_img in ls_imgs:
|
||||
t0 = time.time()
|
||||
file_name = ind_img.split('.')[0]
|
||||
file_name = Path(ind_img).stem
|
||||
dir_img = os.path.join(self.dir_in, ind_img)
|
||||
dir_xml = os.path.join(self.dir_xmls, file_name+'.xml')
|
||||
out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
|
||||
|
@ -5368,11 +5378,11 @@ class Eynollah_ocr:
|
|||
if not self.export_textline_images_and_text:
|
||||
extracted_texts = []
|
||||
|
||||
n_iterations = math.ceil(len(cropped_lines) / b_s)
|
||||
n_iterations = math.ceil(len(cropped_lines) / self.b_s)
|
||||
|
||||
for i in range(n_iterations):
|
||||
if i==(n_iterations-1):
|
||||
n_start = i*b_s
|
||||
n_start = i*self.b_s
|
||||
imgs = cropped_lines[n_start:]
|
||||
imgs = np.array(imgs)
|
||||
imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
|
||||
|
@ -5381,14 +5391,14 @@ class Eynollah_ocr:
|
|||
imgs_bin = np.array(imgs_bin)
|
||||
imgs_bin = imgs_bin.reshape(imgs_bin.shape[0], image_height, image_width, 3)
|
||||
else:
|
||||
n_start = i*b_s
|
||||
n_end = (i+1)*b_s
|
||||
n_start = i*self.b_s
|
||||
n_end = (i+1)*self.b_s
|
||||
imgs = cropped_lines[n_start:n_end]
|
||||
imgs = np.array(imgs).reshape(b_s, image_height, image_width, 3)
|
||||
imgs = np.array(imgs).reshape(self.b_s, image_height, image_width, 3)
|
||||
|
||||
if self.prediction_with_both_of_rgb_and_bin:
|
||||
imgs_bin = cropped_lines_bin[n_start:n_end]
|
||||
imgs_bin = np.array(imgs_bin).reshape(b_s, image_height, image_width, 3)
|
||||
imgs_bin = np.array(imgs_bin).reshape(self.b_s, image_height, image_width, 3)
|
||||
|
||||
|
||||
preds = self.prediction_model.predict(imgs, verbose=0)
|
||||
|
@ -5402,6 +5412,11 @@ class Eynollah_ocr:
|
|||
pred_texts_ib = pred_texts[ib].replace("[UNK]", "")
|
||||
extracted_texts.append(pred_texts_ib)
|
||||
|
||||
del cropped_lines
|
||||
if self.prediction_with_both_of_rgb_and_bin:
|
||||
del cropped_lines_bin
|
||||
gc.collect()
|
||||
|
||||
extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+" "+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
|
||||
|
||||
extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue