diff --git a/src/eynollah/cli.py b/src/eynollah/cli.py
index ff612b2..c306ac5 100644
--- a/src/eynollah/cli.py
+++ b/src/eynollah/cli.py
@@ -347,6 +347,18 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     is_flag=True,
     help="if this parameter set to true, transformer ocr will be applied, otherwise cnn_rnn model.",
 )
+@click.option(
+    "--export_textline_images_and_text",
+    "-etit/-noetit",
+    is_flag=True,
+    help="if this parameter is set to true, textline images and the corresponding text from the XML will be exported into the output directory. These files can be used to train an OCR engine.",
+)
+@click.option(
+    "--do_not_mask_with_textline_contour",
+    "-nmtc/-mtc",
+    is_flag=True,
+    help="if this parameter is set to true, cropped textline images will not be masked with the textline contour.",
+)
 @click.option(
     "--log_level",
     "-l",
@@ -354,7 +366,7 @@ def layout(image, out, overwrite, dir_in, model, save_images, save_layout, save_
     help="Override log level globally to this",
 )
 
-def ocr(dir_in, out, dir_xmls, model, tr_ocr, log_level):
+def ocr(dir_in, out, dir_xmls, model, tr_ocr, export_textline_images_and_text, do_not_mask_with_textline_contour, log_level):
     if log_level:
         setOverrideLogLevel(log_level)
     initLogging()
@@ -364,6 +376,8 @@ def ocr(dir_in, out, dir_xmls, model, tr_ocr, log_level):
         dir_out=out,
         dir_models=model,
         tr_ocr=tr_ocr,
+        export_textline_images_and_text=export_textline_images_and_text,
+        do_not_mask_with_textline_contour=do_not_mask_with_textline_contour,
     )
     eynollah_ocr.run()
diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py
index 65e85c5..7acee39 100644
--- a/src/eynollah/eynollah.py
+++ b/src/eynollah/eynollah.py
@@ -4946,6 +4946,8 @@ class Eynollah_ocr:
         dir_in=None,
         dir_out=None,
         tr_ocr=False,
+        export_textline_images_and_text=False,
+        do_not_mask_with_textline_contour=False,
         logger=None,
     ):
         self.dir_in = dir_in
@@ -4953,6 +4955,8 @@ class Eynollah_ocr:
         self.dir_xmls = dir_xmls
         self.dir_models = dir_models
         self.tr_ocr = tr_ocr
+        self.export_textline_images_and_text = export_textline_images_and_text
+        self.do_not_mask_with_textline_contour = do_not_mask_with_textline_contour
         if tr_ocr:
             self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -4961,7 +4965,7 @@ class Eynollah_ocr:
             self.model_ocr.to(self.device)
 
         else:
-            self.model_ocr_dir = dir_models + "/model_1_new_ocrcnn"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
+            self.model_ocr_dir = dir_models + "/model_3_new_ocrcnn"#"/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn"
             model_ocr = load_model(self.model_ocr_dir , compile=False)
 
             self.prediction_model = tf.keras.models.Model(
@@ -5107,7 +5111,7 @@ class Eynollah_ocr:
                 img = cv2.imread(dir_img)
                 ##file_name = Path(dir_xmls).stem
-                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5'))
+                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
                 root1=tree1.getroot()
                 alltags=[elem.tag for elem in root1.iter()]
                 link=alltags[0].split('}')[0]+'}'
@@ -5241,7 +5245,7 @@ class Eynollah_ocr:
                 out_file_ocr = os.path.join(self.dir_out, file_name+'.xml')
                 img = cv2.imread(dir_img)
 
-                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5'))
+                tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding="utf-8"))
                 root1=tree1.getroot()
                 alltags=[elem.tag for elem in root1.iter()]
                 link=alltags[0].split('}')[0]+'}'
@@ -5257,15 +5261,16 @@ class Eynollah_ocr:
                 tinl = time.time()
                 indexer_text_region = 0
+                indexer_textlines = 0
                 for nn in root1.iter(region_tags):
                     for child_textregion in nn:
                         if child_textregion.tag.endswith("TextLine"):
-
                             for child_textlines in child_textregion:
                                 if child_textlines.tag.endswith("Coords"):
                                     cropped_lines_region_indexer.append(indexer_text_region)
                                     p_h=child_textlines.attrib['points'].split(' ')
                                     textline_coords = np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ] for x in p_h] )
+
                                     x,y,w,h = cv2.boundingRect(textline_coords)
 
                                     h2w_ratio = h/float(w)
@@ -5276,104 +5281,101 @@ class Eynollah_ocr:
                                     mask_poly = mask_poly[y:y+h, x:x+w, :]
                                     img_crop = img_poly_on_img[y:y+h, x:x+w, :]
-                                    img_crop[mask_poly==0] = 255
+                                    if not self.do_not_mask_with_textline_contour:
+                                        img_crop[mask_poly==0] = 255
 
-                                    if h2w_ratio > 0.05:
-                                        img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
-                                        cropped_lines.append(img_fin)
-                                        cropped_lines_meging_indexing.append(0)
-                                    else:
-                                        splited_images = self.return_textlines_split_if_needed(img_crop)
-                                        #print(splited_images)
-                                        if splited_images:
-                                            img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(splited_images[0], image_height, image_width)
-                                            cropped_lines.append(img_fin)
-                                            cropped_lines_meging_indexing.append(1)
-                                            img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(splited_images[1], image_height, image_width)
-
-                                            cropped_lines.append(img_fin)
-                                            cropped_lines_meging_indexing.append(-1)
-                                        else:
+                                    if not self.export_textline_images_and_text:
+                                        if h2w_ratio > 0.05:
                                             img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
                                             cropped_lines.append(img_fin)
                                             cropped_lines_meging_indexing.append(0)
-                                    #img_crop = tf.reverse(img_crop,axis=[-1])
-                                    #img_crop = self.distortion_free_resize(img_crop, img_size)
-                                    #img_crop = tf.cast(img_crop, tf.float32) / 255.0
-
-                                    #cropped_lines.append(img_crop)
-
-                    indexer_text_region = indexer_text_region +1
+                                        else:
+                                            splited_images = self.return_textlines_split_if_needed(img_crop)
+                                            if splited_images:
+                                                img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(splited_images[0], image_height, image_width)
+                                                cropped_lines.append(img_fin)
+                                                cropped_lines_meging_indexing.append(1)
+                                                img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(splited_images[1], image_height, image_width)
+
+                                                cropped_lines.append(img_fin)
+                                                cropped_lines_meging_indexing.append(-1)
+                                            else:
+                                                img_fin = self.preprocess_and_resize_image_for_ocrcnn_model(img_crop, image_height, image_width)
+                                                cropped_lines.append(img_fin)
+                                                cropped_lines_meging_indexing.append(0)
+
+                                if self.export_textline_images_and_text:
+                                    if child_textlines.tag.endswith("TextEquiv"):
+                                        for child_text in child_textlines:
+                                            if child_text.tag.endswith("Unicode"):
+                                                textline_text = child_text.text
+                                                if textline_text:
+                                                    with open(os.path.join(self.dir_out, file_name+'_line_'+str(indexer_textlines)+'.txt'), 'w') as text_file:
+                                                        text_file.write(textline_text)
+
+                                                    cv2.imwrite(os.path.join(self.dir_out, file_name+'_line_'+str(indexer_textlines)+'.png'), img_crop)
+
+                                                indexer_textlines += 1
+
+                    if not self.export_textline_images_and_text:
+                        indexer_text_region = indexer_text_region + 1
 
-
-                extracted_texts = []
+                if not self.export_textline_images_and_text:
+                    extracted_texts = []
 
-                n_iterations = math.ceil(len(cropped_lines) / b_s)
+                    n_iterations = math.ceil(len(cropped_lines) / b_s)
 
-                for i in range(n_iterations):
-                    if i==(n_iterations-1):
-                        n_start = i*b_s
-                        imgs = cropped_lines[n_start:]
-                        imgs = np.array(imgs)
-                        imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
-                    else:
-                        n_start = i*b_s
-                        n_end = (i+1)*b_s
-                        imgs = cropped_lines[n_start:n_end]
-                        imgs = np.array(imgs).reshape(b_s, image_height, image_width, 3)
-
+                    for i in range(n_iterations):
+                        if i==(n_iterations-1):
+                            n_start = i*b_s
+                            imgs = cropped_lines[n_start:]
+                            imgs = np.array(imgs)
+                            imgs = imgs.reshape(imgs.shape[0], image_height, image_width, 3)
+                        else:
+                            n_start = i*b_s
+                            n_end = (i+1)*b_s
+                            imgs = cropped_lines[n_start:n_end]
+                            imgs = np.array(imgs).reshape(b_s, image_height, image_width, 3)
+
-                    preds = self.prediction_model.predict(imgs, verbose=0)
-                    pred_texts = self.decode_batch_predictions(preds)
+                        preds = self.prediction_model.predict(imgs, verbose=0)
+                        pred_texts = self.decode_batch_predictions(preds)
 
-                    for ib in range(imgs.shape[0]):
-                        pred_texts_ib = pred_texts[ib].strip("[UNK]")
-                        extracted_texts.append(pred_texts_ib)
-
-
-                extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
-
-                extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
-                #print(extracted_texts_merged, len(extracted_texts_merged))
+                        for ib in range(imgs.shape[0]):
+                            pred_texts_ib = pred_texts[ib].strip("[UNK]")
+                            extracted_texts.append(pred_texts_ib)
+
+                    extracted_texts_merged = [extracted_texts[ind] if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))]
 
-                unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
+                    extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None]
+                    unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
 
-                #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer')
-                text_by_textregion = []
-                for ind in unique_cropped_lines_region_indexer:
-                    extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind]
-
-                    text_by_textregion.append(" ".join(extracted_texts_merged_un))
-
-
-
-                ##unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer)
-
-                ##text_by_textregion = []
-                ##for ind in unique_cropped_lines_region_indexer:
-                    ##extracted_texts_merged_un = np.array(extracted_texts)[np.array(cropped_lines_region_indexer)==ind]
-
-                    ##text_by_textregion.append(" ".join(extracted_texts_merged_un))
-
-                indexer = 0
-                indexer_textregion = 0
-                for nn in root1.iter(region_tags):
-                    text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
-                    unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
-
-
-                    has_textline = False
-                    for child_textregion in nn:
-                        if child_textregion.tag.endswith("TextLine"):
-                            text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
-                            unicode_textline = ET.SubElement(text_subelement, 'Unicode')
-                            unicode_textline.text = extracted_texts_merged[indexer]
-                            indexer = indexer + 1
-                            has_textline = True
-                    if has_textline:
-                        unicode_textregion.text = text_by_textregion[indexer_textregion]
-                        indexer_textregion = indexer_textregion + 1
+                    text_by_textregion = []
+                    for ind in unique_cropped_lines_region_indexer:
+                        extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind]
+                        text_by_textregion.append(" ".join(extracted_texts_merged_un))
+
+                    indexer = 0
+                    indexer_textregion = 0
+                    for nn in root1.iter(region_tags):
+                        text_subelement_textregion = ET.SubElement(nn, 'TextEquiv')
+                        unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode')
 
-                ET.register_namespace("",name_space)
-                tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
-                #print("Job done in %.1fs", time.time() - t0)
+
+                        has_textline = False
+                        for child_textregion in nn:
+                            if child_textregion.tag.endswith("TextLine"):
+                                text_subelement = ET.SubElement(child_textregion, 'TextEquiv')
+                                unicode_textline = ET.SubElement(text_subelement, 'Unicode')
+                                unicode_textline.text = extracted_texts_merged[indexer]
+                                indexer = indexer + 1
+                                has_textline = True
+                        if has_textline:
+                            unicode_textregion.text = text_by_textregion[indexer_textregion]
+                            indexer_textregion = indexer_textregion + 1
+
+                    ET.register_namespace("",name_space)
+                    tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None)
+                    #print("Job done in %.1fs", time.time() - t0)
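
For reference, the two new switches can also be driven directly through the `Eynollah_ocr` constructor rather than the CLI. A minimal sketch (all paths are placeholders; the keyword names, `run()`, and the `model_3_new_ocrcnn` directory come from the patch):

```python
from eynollah.eynollah import Eynollah_ocr

eynollah_ocr = Eynollah_ocr(
    dir_xmls="page_xml/",                     # existing PAGE-XML layout files
    dir_in="images/",                         # page images matching the XML files
    dir_out="ocr_out/",                       # receives <name>_line_<N>.png/.txt pairs in export mode
    dir_models="models/",                     # must contain model_3_new_ocrcnn for the CNN-RNN branch
    tr_ocr=False,                             # the patch only touches the CNN-RNN code path
    export_textline_images_and_text=True,     # dump training pairs instead of writing OCR output
    do_not_mask_with_textline_contour=False,  # keep whiting out pixels outside the line polygon
)
eynollah_ocr.run()
```

On the command line the same toggles are `-etit/-noetit` and `-nmtc/-mtc` on the `ocr` subcommand.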
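The behavioral core of `--do_not_mask_with_textline_contour` is the masking step the patch wraps in a guard. Below is a standalone sketch of that step; the `crop_textline` helper is hypothetical and the `cv2.fillPoly` mask construction is illustrative, since the patch reuses a `mask_poly` computed before the visible hunk:

```python
import cv2
import numpy as np

def crop_textline(img, textline_coords, mask_with_contour=True):
    """Crop one textline from the page image; optionally white out
    everything outside the line polygon, as the patched code does
    unless do_not_mask_with_textline_contour is set."""
    pts = textline_coords.astype(np.int32)  # OpenCV expects int32 point arrays
    x, y, w, h = cv2.boundingRect(pts)
    # Illustrative mask; the patch reuses a mask computed earlier.
    mask_poly = np.zeros(img.shape, dtype=np.uint8)
    cv2.fillPoly(mask_poly, pts=[pts], color=(1, 1, 1))
    img_crop = img[y:y+h, x:x+w, :].copy()
    mask_crop = mask_poly[y:y+h, x:x+w, :]
    if mask_with_contour:
        img_crop[mask_crop == 0] = 255  # blank neighboring lines bleeding into the box
    return img_crop
```

Unmasked crops keep ascenders and descenders of adjacent lines inside the bounding box; whether that is desirable depends on the OCR training setup, which is presumably why the patch makes masking optional.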
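The merge bookkeeping preserved by the re-indented block is easy to misread: `cropped_lines_meging_indexing` records 0 for a textline predicted in one piece, 1 for the first half of a line split by `return_textlines_split_if_needed`, and -1 for the second half. An equivalent standalone version of the merge comprehension (the function name is ours):

```python
def merge_split_predictions(texts, merge_flags):
    # 0: keep prediction as-is; 1: concatenate with the following prediction;
    # -1: second half of a split line, already consumed by its predecessor.
    merged = [
        texts[i] if merge_flags[i] == 0
        else texts[i] + texts[i + 1] if merge_flags[i] == 1
        else None
        for i in range(len(merge_flags))
    ]
    return [t for t in merged if t is not None]

# e.g. merge_split_predictions(["foo", "ba", "r"], [0, 1, -1]) -> ["foo", "bar"]
```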