Mirror of https://github.com/qurator-spk/eynollah.git (synced 2025-10-27 15:54:13 +01:00)
Provide OCR as an option: process a directory of images using XML files that carry the layout and text line coordinates.
This commit is contained in:
parent fbeef79d50
commit 92bfac4b41
2 changed files with 459 additions and 11 deletions
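For illustration, here is a minimal sketch of how the new ocr subcommand maps onto the Eynollah_ocr class added in this commit (the directory paths are placeholders, and the "eynollah" console entry point is assumed):

    from eynollah.eynollah import Eynollah_ocr

    # Roughly equivalent to: eynollah ocr -di images/ -dx page_xml/ -o ocr_output/ -m models/
    eynollah_ocr = Eynollah_ocr(
        dir_models="models/",    # placeholder: directory containing the OCR model(s)
        dir_xmls="page_xml/",    # placeholder: PAGE-XML files with layout and text line coordinates
        dir_in="images/",        # placeholder: directory of page images
        dir_out="ocr_output/",   # placeholder: output directory for the enriched XML
        tr_ocr=False,            # False: CNN-RNN model, True: TrOCR
    )
    eynollah_ocr.run()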
		|  | @ -1,7 +1,7 @@ | |||
| import sys | ||||
| import click | ||||
| from ocrd_utils import initLogging, setOverrideLogLevel | ||||
| from eynollah.eynollah import Eynollah | ||||
| from eynollah.eynollah import Eynollah, Eynollah_ocr | ||||
| from eynollah.sbb_binarize import SbbBinarizer | ||||
| 
 | ||||
| @click.group() | ||||
|  | @ -305,6 +305,60 @@ def layout(image, out, dir_in, model, save_images, save_layout, save_deskewed, s | |||
|     else: | ||||
|         pcgts = eynollah.run() | ||||
|         eynollah.writer.write_pagexml(pcgts) | ||||
|          | ||||
|          | ||||
| @main.command() | ||||
| @click.option( | ||||
|     "--dir_in", | ||||
|     "-di", | ||||
|     help="directory of images", | ||||
|     type=click.Path(exists=True, file_okay=False), | ||||
| ) | ||||
| @click.option( | ||||
|     "--out", | ||||
|     "-o", | ||||
|     help="directory to write output xml data", | ||||
|     type=click.Path(exists=True, file_okay=False), | ||||
|     required=True, | ||||
| ) | ||||
| @click.option( | ||||
|     "--dir_xmls", | ||||
|     "-dx", | ||||
|     help="directory of xmls", | ||||
|     type=click.Path(exists=True, file_okay=False), | ||||
| ) | ||||
| @click.option( | ||||
|     "--model", | ||||
|     "-m", | ||||
|     help="directory of models", | ||||
|     type=click.Path(exists=True, file_okay=False), | ||||
|     required=True, | ||||
| ) | ||||
| @click.option( | ||||
|     "--tr_ocr", | ||||
|     "-trocr/-notrocr", | ||||
|     is_flag=True, | ||||
|     help="If set, the transformer OCR (TrOCR) model is applied; otherwise the CNN-RNN model is used.", | ||||
| ) | ||||
| @click.option( | ||||
|     "--log_level", | ||||
|     "-l", | ||||
|     type=click.Choice(['OFF', 'DEBUG', 'INFO', 'WARN', 'ERROR']), | ||||
|     help="Override log level globally to this", | ||||
| ) | ||||
| 
 | ||||
| def ocr(dir_in, out, dir_xmls, model, tr_ocr, log_level): | ||||
|     if log_level: | ||||
|         setOverrideLogLevel(log_level) | ||||
|     initLogging() | ||||
|     eynollah_ocr = Eynollah_ocr( | ||||
|         dir_xmls=dir_xmls, | ||||
|         dir_in=dir_in, | ||||
|         dir_out=out, | ||||
|         dir_models=model, | ||||
|         tr_ocr=tr_ocr, | ||||
|     ) | ||||
|     eynollah_ocr.run() | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  |  | |||
|  | @ -41,6 +41,9 @@ import matplotlib.pyplot as plt | |||
| # use tf1 compatibility for keras backend | ||||
| from tensorflow.compat.v1.keras.backend import set_session | ||||
| from tensorflow.keras import layers | ||||
| import json | ||||
| import xml.etree.ElementTree as ET | ||||
| from tensorflow.keras.layers import StringLookup | ||||
| 
 | ||||
| from .utils.contour import ( | ||||
|     filter_contours_area_of_image, | ||||
|  | @ -2188,18 +2191,18 @@ class Eynollah: | |||
|         img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w)) | ||||
|          | ||||
|         if not self.dir_in: | ||||
|             prediction_textline = self.do_prediction(patches, img, model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) | ||||
|             ###prediction_textline = self.do_prediction(patches, img, model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3, thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) | ||||
|              | ||||
|             ##prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, model_textline, n_batch_inference=3) | ||||
|             prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, model_textline, n_batch_inference=3) | ||||
|              | ||||
|             #if not thresholding_for_artificial_class_in_light_version: | ||||
|                 #if num_col_classifier==1: | ||||
|                     #prediction_textline_nopatch = self.do_prediction(False, img, model_textline) | ||||
|                     #prediction_textline[:,:][prediction_textline_nopatch[:,:]==0] = 0 | ||||
|         else: | ||||
|             prediction_textline = self.do_prediction(patches, img, self.model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3,thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) | ||||
|             ##prediction_textline = self.do_prediction(patches, img, self.model_textline, marginal_of_patch_percent=0.15, n_batch_inference=3,thresholding_for_artificial_class_in_light_version=thresholding_for_artificial_class_in_light_version) | ||||
|              | ||||
|             ###prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, self.model_textline, n_batch_inference=3) | ||||
|             prediction_textline = self.do_prediction_new_concept_scatter_nd(patches, img, self.model_textline, n_batch_inference=3) | ||||
|             #if not thresholding_for_artificial_class_in_light_version: | ||||
|                 #if num_col_classifier==1: | ||||
|                     #prediction_textline_nopatch = self.do_prediction(False, img, model_textline) | ||||
|  | @ -2479,17 +2482,17 @@ class Eynollah: | |||
|                 if num_col_classifier == 1 or num_col_classifier == 2: | ||||
|                     model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np) | ||||
|                     if self.image_org.shape[0]/self.image_org.shape[1] > 2.5: | ||||
|                         ##prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) | ||||
|                         prediction_regions_org = self.do_prediction_new_concept(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) | ||||
|                         prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) | ||||
|                         ###prediction_regions_org = self.do_prediction_new_concept(True, img_resized, model_region, n_batch_inference=1, thresholding_for_some_classes_in_light_version = True) | ||||
|                     else: | ||||
|                         prediction_regions_org = np.zeros((self.image_org.shape[0], self.image_org.shape[1], 3)) | ||||
|                         ##prediction_regions_page = self.do_prediction_new_concept_scatter_nd(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) | ||||
|                         prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) | ||||
|                         prediction_regions_page = self.do_prediction_new_concept_scatter_nd(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) | ||||
|                         ##prediction_regions_page = self.do_prediction_new_concept(False, self.image_page_org_size, model_region, n_batch_inference=1, thresholding_for_artificial_class_in_light_version = True) | ||||
|                         prediction_regions_org[self.page_coord[0] : self.page_coord[1], self.page_coord[2] : self.page_coord[3],:] = prediction_regions_page | ||||
|                 else: | ||||
|                     model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np) | ||||
|                     prediction_regions_org = self.do_prediction_new_concept(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) | ||||
|                     ###prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) | ||||
|                     ###prediction_regions_org = self.do_prediction_new_concept(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) | ||||
|                     prediction_regions_org = self.do_prediction_new_concept_scatter_nd(True, resize_image(img_bin, int( (900+ (num_col_classifier-3)*100) *(img_bin.shape[0]/img_bin.shape[1]) ), 900+ (num_col_classifier-3)*100), model_region, n_batch_inference=2, thresholding_for_some_classes_in_light_version=True) | ||||
|                 ##model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens_light) | ||||
|                 ##prediction_regions_org = self.do_prediction(True, img_bin, model_region, n_batch_inference=3, thresholding_for_some_classes_in_light_version=True) | ||||
|             else: | ||||
|  | @ -5610,3 +5613,394 @@ class Eynollah: | |||
|         if self.dir_in: | ||||
|             self.logger.info("All jobs done in %.1fs", time.time() - t0_tot) | ||||
|             print("all Job done in %.1fs", time.time() - t0_tot) | ||||
|              | ||||
|              | ||||
| class Eynollah_ocr: | ||||
|     def __init__( | ||||
|         self, | ||||
|         dir_models, | ||||
|         dir_xmls=None, | ||||
|         dir_in=None, | ||||
|         dir_out=None, | ||||
|         tr_ocr=False, | ||||
|         logger=None, | ||||
|     ): | ||||
|         self.dir_in = dir_in | ||||
|         self.dir_out = dir_out | ||||
|         self.dir_xmls = dir_xmls | ||||
|         self.dir_models = dir_models | ||||
|         self.tr_ocr = tr_ocr | ||||
|         if tr_ocr: | ||||
|             self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed") | ||||
|             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||||
|             self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124" | ||||
|             self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir) | ||||
|             self.model_ocr.to(self.device) | ||||
| 
 | ||||
|         else: | ||||
|             self.model_ocr_dir = dir_models + "/model_0_ocr_cnnrnn"#"/model_23_ocr_cnnrnn" | ||||
|             model_ocr = load_model(self.model_ocr_dir , compile=False) | ||||
|              | ||||
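|             # Inference-only model: take the "image" input and the "dense2" layer's | ||||
|             # per-timestep character predictions; CTC decoding is done later in decode_batch_predictions(). | ||||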
|             self.prediction_model = tf.keras.models.Model( | ||||
|                             model_ocr.get_layer(name = "image").input,  | ||||
|                             model_ocr.get_layer(name = "dense2").output) | ||||
| 
 | ||||
|                  | ||||
|             with open(os.path.join(self.model_ocr_dir, "characters_org.txt"),"r") as config_file: | ||||
|                 characters = json.load(config_file) | ||||
| 
 | ||||
|                  | ||||
|             AUTOTUNE = tf.data.AUTOTUNE | ||||
| 
 | ||||
|             # Mapping characters to integers. | ||||
|             char_to_num = StringLookup(vocabulary=list(characters), mask_token=None) | ||||
| 
 | ||||
|             # Mapping integers back to original characters. | ||||
|             self.num_to_char = StringLookup( | ||||
|                 vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True | ||||
|             ) | ||||
|          | ||||
|     def decode_batch_predictions(self, pred, max_len = 128): | ||||
|         # input_len holds one entry per batch item, each equal to | ||||
|         # the number of time steps in the prediction. | ||||
|         input_len = np.ones(pred.shape[0]) * pred.shape[1] | ||||
|          | ||||
|         # Decode CTC predictions. ctc_decode defaults to greedy decoding, | ||||
|         # so the beam_width argument has no effect here. | ||||
|         # decoded is a tuple with 2 elements. | ||||
|         decoded = tf.keras.backend.ctc_decode(pred, | ||||
|                                               input_length=input_len, | ||||
|                                               beam_width=100) | ||||
|         # The outputs are in the first element of the tuple. | ||||
|         # Additionally, the first element is actually a list, | ||||
|         # therefore we take the first element of that list as well. | ||||
|         #print(decoded,'decoded') | ||||
|         decoded = decoded[0][0][:, :max_len] | ||||
|          | ||||
|         #print(decoded, decoded.shape,'decoded') | ||||
| 
 | ||||
|         output = [] | ||||
|         for d in decoded: | ||||
|             # Convert the predicted indices to the corresponding chars. | ||||
|             d = tf.strings.reduce_join(self.num_to_char(d)) | ||||
|             d = d.numpy().decode("utf-8") | ||||
|             output.append(d) | ||||
|         return output | ||||
|          | ||||
|          | ||||
|     def distortion_free_resize(self, image, img_size): | ||||
|         w, h = img_size | ||||
|         image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True) | ||||
| 
 | ||||
|         # Check the amount of padding needed. | ||||
|         pad_height = h - tf.shape(image)[0] | ||||
|         pad_width = w - tf.shape(image)[1] | ||||
| 
 | ||||
|         # Only necessary if you want the same amount of padding on both sides. | ||||
|         if pad_height % 2 != 0: | ||||
|             height = pad_height // 2 | ||||
|             pad_height_top = height + 1 | ||||
|             pad_height_bottom = height | ||||
|         else: | ||||
|             pad_height_top = pad_height_bottom = pad_height // 2 | ||||
| 
 | ||||
|         if pad_width % 2 != 0: | ||||
|             width = pad_width // 2 | ||||
|             pad_width_left = width + 1 | ||||
|             pad_width_right = width | ||||
|         else: | ||||
|             pad_width_left = pad_width_right = pad_width // 2 | ||||
| 
 | ||||
|         image = tf.pad( | ||||
|             image, | ||||
|             paddings=[ | ||||
|                 [pad_height_top, pad_height_bottom], | ||||
|                 [pad_width_left, pad_width_right], | ||||
|                 [0, 0], | ||||
|             ], | ||||
|         ) | ||||
| 
 | ||||
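|         # Transpose and flip so the width axis becomes the first (time) axis, | ||||
|         # the orientation the recognition model expects. | ||||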
|         image = tf.transpose(image, (1, 0, 2)) | ||||
|         image = tf.image.flip_left_right(image) | ||||
|         return image | ||||
|      | ||||
|     def return_start_and_end_of_common_text_of_textline_ocr_without_common_section(self, textline_image): | ||||
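|         # Find a candidate split column near the centre of a wide text line: smooth the | ||||
|         # vertical projection profile and, if it has enough peaks, return the strongest | ||||
|         # peak inside a small window around the middle (for dark text on a light | ||||
|         # background, peaks correspond to low-ink columns such as gaps between words). | ||||
|         # Returns None if the profile has too few peaks. | ||||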
|         width = np.shape(textline_image)[1] | ||||
|         height = np.shape(textline_image)[0] | ||||
|         common_window = int(0.06*width) | ||||
| 
 | ||||
|         width1 = int ( width/2. - common_window ) | ||||
|         width2 = int ( width/2. + common_window ) | ||||
|          | ||||
|         img_sum = np.sum(textline_image[:,:,0], axis=0) | ||||
|         sum_smoothed = gaussian_filter1d(img_sum, 3) | ||||
|          | ||||
|         peaks_real, _ = find_peaks(sum_smoothed, height=0) | ||||
|          | ||||
|         if len(peaks_real)>70: | ||||
| 
 | ||||
|             peaks_real = peaks_real[(peaks_real<width2) & (peaks_real>width1)] | ||||
| 
 | ||||
|             arg_max = np.argmax(sum_smoothed[peaks_real]) | ||||
| 
 | ||||
|             peaks_final = peaks_real[arg_max] | ||||
|              | ||||
|             return peaks_final | ||||
|         else: | ||||
|             return None | ||||
|      | ||||
|     def return_textlines_split_if_needed(self, textline_image): | ||||
| 
 | ||||
|         split_point = self.return_start_and_end_of_common_text_of_textline_ocr_without_common_section(textline_image) | ||||
|         if split_point: | ||||
|             image1 = textline_image[:, :split_point,:]# image.crop((0, 0, width2, height)) | ||||
|             image2 = textline_image[:, split_point:,:]#image.crop((width1, 0, width, height)) | ||||
|             return [image1, image2] | ||||
|         else: | ||||
|             return None | ||||
|      | ||||
|     def run(self): | ||||
|         ls_imgs = os.listdir(self.dir_in) | ||||
|          | ||||
|         if self.tr_ocr: | ||||
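|             # TrOCR branch: crop each text line with its PAGE-XML polygon and recognise | ||||
|             # it with the VisionEncoderDecoder model in batches of b_s. | ||||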
|             b_s = 2 | ||||
|             for ind_img in ls_imgs: | ||||
|                 t0 = time.time() | ||||
|                 file_name = ind_img.split('.')[0] | ||||
|                 dir_img = os.path.join(self.dir_in, ind_img) | ||||
|                 dir_xml = os.path.join(self.dir_xmls, file_name+'.xml') | ||||
|                 out_file_ocr = os.path.join(self.dir_out, file_name+'.xml') | ||||
|                 img = cv2.imread(dir_img) | ||||
| 
 | ||||
|                 ##file_name = Path(dir_xmls).stem | ||||
|                 tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5')) | ||||
|                 root1=tree1.getroot() | ||||
|                 alltags=[elem.tag for elem in root1.iter()] | ||||
|                 link=alltags[0].split('}')[0]+'}' | ||||
| 
 | ||||
|                 name_space = alltags[0].split('}')[0] | ||||
|                 name_space = name_space.split('{')[1] | ||||
| 
 | ||||
|                 region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])  | ||||
|                          | ||||
|                      | ||||
|                      | ||||
|                 cropped_lines = [] | ||||
|                 cropped_lines_region_indexer = [] | ||||
|                 cropped_lines_meging_indexing = [] | ||||
| 
 | ||||
|                 indexer_text_region = 0 | ||||
|                 for nn in root1.iter(region_tags): | ||||
|                     for child_textregion in nn: | ||||
|                         if child_textregion.tag.endswith("TextLine"): | ||||
|                              | ||||
|                             for child_textlines in child_textregion: | ||||
|                                 if child_textlines.tag.endswith("Coords"): | ||||
|                                     cropped_lines_region_indexer.append(indexer_text_region) | ||||
|                                     p_h=child_textlines.attrib['points'].split(' ') | ||||
|                                     textline_coords =  np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ]  for x in p_h] ) | ||||
|                                     x,y,w,h = cv2.boundingRect(textline_coords) | ||||
|                                      | ||||
|                                     h2w_ratio = h/float(w) | ||||
|                                      | ||||
|                                     img_poly_on_img = np.copy(img) | ||||
|                                     mask_poly = np.zeros(img.shape) | ||||
|                                     mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) | ||||
|                                      | ||||
|                                     mask_poly = mask_poly[y:y+h, x:x+w, :] | ||||
|                                     img_crop = img_poly_on_img[y:y+h, x:x+w, :] | ||||
|                                     img_crop[mask_poly==0] = 255 | ||||
|                                      | ||||
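|                                     # Lines with a reasonable height/width ratio go to the model as | ||||
|                                     # they are; very wide, flat lines (h2w_ratio <= 0.05) are split near | ||||
|                                     # the centre when possible and their halves are re-joined after OCR. | ||||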
|                                     if h2w_ratio > 0.05: | ||||
|                                         cropped_lines.append(img_crop) | ||||
|                                         cropped_lines_meging_indexing.append(0) | ||||
|                                     else: | ||||
|                                         splited_images = self.return_textlines_split_if_needed(img_crop) | ||||
|                                         #print(splited_images) | ||||
|                                         if splited_images: | ||||
|                                             cropped_lines.append(splited_images[0]) | ||||
|                                             cropped_lines_meging_indexing.append(1) | ||||
|                                             cropped_lines.append(splited_images[1]) | ||||
|                                             cropped_lines_meging_indexing.append(-1) | ||||
|                                         else: | ||||
|                                             cropped_lines.append(img_crop) | ||||
|                                             cropped_lines_meging_indexing.append(0) | ||||
|                     indexer_text_region = indexer_text_region +1 | ||||
|          | ||||
|          | ||||
|                 extracted_texts = [] | ||||
|                 n_iterations  = math.ceil(len(cropped_lines) / b_s)  | ||||
| 
 | ||||
|                 for i in range(n_iterations): | ||||
|                     if i==(n_iterations-1): | ||||
|                         n_start = i*b_s | ||||
|                         imgs = cropped_lines[n_start:] | ||||
|                     else: | ||||
|                         n_start = i*b_s | ||||
|                         n_end = (i+1)*b_s | ||||
|                         imgs = cropped_lines[n_start:n_end] | ||||
|                     pixel_values_merged = self.processor(imgs, return_tensors="pt").pixel_values | ||||
|                     generated_ids_merged = self.model_ocr.generate(pixel_values_merged.to(self.device)) | ||||
|                     generated_text_merged = self.processor.batch_decode(generated_ids_merged, skip_special_tokens=True) | ||||
|                      | ||||
|                     extracted_texts = extracted_texts + generated_text_merged | ||||
| 
 | ||||
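|                 # Re-join split lines: an entry flagged 1 (left half) is concatenated with | ||||
|                 # the following -1 entry (right half); -1 entries themselves become None and | ||||
|                 # are filtered out below, while entries flagged 0 are kept unchanged. | ||||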
|                 extracted_texts_merged = [extracted_texts[ind]  if cropped_lines_meging_indexing[ind]==0 else extracted_texts[ind]+extracted_texts[ind+1] if cropped_lines_meging_indexing[ind]==1 else None for ind in range(len(cropped_lines_meging_indexing))] | ||||
| 
 | ||||
|                 extracted_texts_merged = [ind for ind in extracted_texts_merged if ind is not None] | ||||
|                 #print(extracted_texts_merged, len(extracted_texts_merged)) | ||||
| 
 | ||||
|                 unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) | ||||
| 
 | ||||
|                 #print(len(unique_cropped_lines_region_indexer), 'unique_cropped_lines_region_indexer') | ||||
|                 text_by_textregion = [] | ||||
|                 for ind in unique_cropped_lines_region_indexer: | ||||
|                     extracted_texts_merged_un = np.array(extracted_texts_merged)[np.array(cropped_lines_region_indexer)==ind] | ||||
|                      | ||||
|                     text_by_textregion.append(" ".join(extracted_texts_merged_un)) | ||||
|                      | ||||
|                 #print(len(text_by_textregion) , indexer_text_region, "text_by_textregion") | ||||
| 
 | ||||
| 
 | ||||
|                 #print(time.time() - t0 ,'elapsed time') | ||||
| 
 | ||||
| 
 | ||||
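|                 # Write the recognised text back into the PAGE-XML tree: one TextEquiv/Unicode | ||||
|                 # per TextLine, plus a TextEquiv per TextRegion joining its lines' text. | ||||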
|                 indexer = 0 | ||||
|                 indexer_textregion = 0 | ||||
|                 for nn in root1.iter(region_tags): | ||||
|                     text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') | ||||
|                     unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') | ||||
| 
 | ||||
|                      | ||||
|                     has_textline = False | ||||
|                     for child_textregion in nn: | ||||
|                         if child_textregion.tag.endswith("TextLine"): | ||||
|                             text_subelement = ET.SubElement(child_textregion, 'TextEquiv') | ||||
|                             unicode_textline = ET.SubElement(text_subelement, 'Unicode') | ||||
|                             unicode_textline.text = extracted_texts_merged[indexer] | ||||
|                             indexer = indexer + 1 | ||||
|                             has_textline = True | ||||
|                     if has_textline: | ||||
|                         unicode_textregion.text = text_by_textregion[indexer_textregion] | ||||
|                         indexer_textregion = indexer_textregion + 1 | ||||
|                          | ||||
| 
 | ||||
| 
 | ||||
|                 ET.register_namespace("",name_space) | ||||
|                 tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) | ||||
|                 #print("Job done in %.1fs", time.time() - t0) | ||||
|         else: | ||||
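|             # CNN-RNN branch: line images are padded/resized to image_width x image_height | ||||
|             # (2048 x 32 here) and decoded with CTC in batches of b_s. | ||||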
|             max_len = 512 | ||||
|             padding_token = 299 | ||||
|             image_width = max_len * 4 | ||||
|             image_height = 32 | ||||
|             b_s = 8 | ||||
| 
 | ||||
| 
 | ||||
|             img_size=(image_width, image_height) | ||||
|              | ||||
|             for ind_img in ls_imgs: | ||||
|                 t0 = time.time() | ||||
|                 file_name = ind_img.split('.')[0] | ||||
|                 dir_img = os.path.join(self.dir_in, ind_img) | ||||
|                 dir_xml = os.path.join(self.dir_xmls, file_name+'.xml') | ||||
|                 out_file_ocr = os.path.join(self.dir_out, file_name+'.xml') | ||||
|                 img = cv2.imread(dir_img) | ||||
| 
 | ||||
|                 tree1 = ET.parse(dir_xml, parser = ET.XMLParser(encoding = 'iso-8859-5')) | ||||
|                 root1=tree1.getroot() | ||||
|                 alltags=[elem.tag for elem in root1.iter()] | ||||
|                 link=alltags[0].split('}')[0]+'}' | ||||
| 
 | ||||
|                 name_space = alltags[0].split('}')[0] | ||||
|                 name_space = name_space.split('{')[1] | ||||
| 
 | ||||
|                 region_tags=np.unique([x for x in alltags if x.endswith('TextRegion')])  | ||||
|                      | ||||
|                 cropped_lines = [] | ||||
|                 cropped_lines_region_indexer = [] | ||||
|                 cropped_lines_meging_indexing = [] | ||||
|                  | ||||
|                 tinl = time.time() | ||||
|                 indexer_text_region = 0 | ||||
|                 for nn in root1.iter(region_tags): | ||||
|                     for child_textregion in nn: | ||||
|                         if child_textregion.tag.endswith("TextLine"): | ||||
|                              | ||||
|                             for child_textlines in child_textregion: | ||||
|                                 if child_textlines.tag.endswith("Coords"): | ||||
|                                     cropped_lines_region_indexer.append(indexer_text_region) | ||||
|                                     p_h=child_textlines.attrib['points'].split(' ') | ||||
|                                     textline_coords =  np.array( [ [ int(x.split(',')[0]) , int(x.split(',')[1]) ]  for x in p_h] ) | ||||
|                                     x,y,w,h = cv2.boundingRect(textline_coords) | ||||
|                                      | ||||
|                                     h2w_ratio = h/float(w) | ||||
|                                      | ||||
|                                     img_poly_on_img = np.copy(img) | ||||
|                                     mask_poly = np.zeros(img.shape) | ||||
|                                     mask_poly = cv2.fillPoly(mask_poly, pts=[textline_coords], color=(1, 1, 1)) | ||||
|                                      | ||||
|                                     mask_poly = mask_poly[y:y+h, x:x+w, :] | ||||
|                                     img_crop = img_poly_on_img[y:y+h, x:x+w, :] | ||||
|                                     img_crop[mask_poly==0] = 255 | ||||
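|                                     # Preprocess for the CNN-RNN model: BGR -> RGB, aspect-preserving | ||||
|                                     # resize with padding to the target size, then scale to [0, 1]. | ||||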
|                                     img_crop = tf.reverse(img_crop,axis=[-1]) | ||||
|                                     img_crop = self.distortion_free_resize(img_crop, img_size) | ||||
|                                     img_crop = tf.cast(img_crop, tf.float32) / 255.0 | ||||
|                                     cropped_lines.append(img_crop) | ||||
| 
 | ||||
|                     indexer_text_region = indexer_text_region +1 | ||||
|                      | ||||
|                  | ||||
|                 extracted_texts = [] | ||||
| 
 | ||||
|                 n_iterations  = math.ceil(len(cropped_lines) / b_s)  | ||||
| 
 | ||||
|                 for i in range(n_iterations): | ||||
|                     if i==(n_iterations-1): | ||||
|                         n_start = i*b_s | ||||
|                         imgs = cropped_lines[n_start:] | ||||
|                         imgs = np.array(imgs) | ||||
|                         imgs = imgs.reshape(imgs.shape[0], image_width, image_height, 3) | ||||
|                     else: | ||||
|                         n_start = i*b_s | ||||
|                         n_end = (i+1)*b_s | ||||
|                         imgs = cropped_lines[n_start:n_end] | ||||
|                         imgs = np.array(imgs).reshape(b_s, image_width, image_height, 3) | ||||
|                          | ||||
| 
 | ||||
|                     preds = self.prediction_model.predict(imgs, verbose=0) | ||||
|                     pred_texts = self.decode_batch_predictions(preds) | ||||
| 
 | ||||
|                     for ib in range(imgs.shape[0]): | ||||
|                         # Remove [UNK] tokens (characters outside the model vocabulary) | ||||
|                         # from the decoded text. | ||||
|                         pred_texts_ib = pred_texts[ib].replace("[UNK]", "") | ||||
|                         extracted_texts.append(pred_texts_ib) | ||||
|                  | ||||
|                 unique_cropped_lines_region_indexer = np.unique(cropped_lines_region_indexer) | ||||
|                  | ||||
|                 text_by_textregion = [] | ||||
|                 for ind in unique_cropped_lines_region_indexer: | ||||
|                     extracted_texts_merged_un = np.array(extracted_texts)[np.array(cropped_lines_region_indexer)==ind] | ||||
|                      | ||||
|                     text_by_textregion.append(" ".join(extracted_texts_merged_un)) | ||||
|                      | ||||
|                 indexer = 0 | ||||
|                 indexer_textregion = 0 | ||||
|                 for nn in root1.iter(region_tags): | ||||
|                     text_subelement_textregion = ET.SubElement(nn, 'TextEquiv') | ||||
|                     unicode_textregion = ET.SubElement(text_subelement_textregion, 'Unicode') | ||||
| 
 | ||||
|                      | ||||
|                     has_textline = False | ||||
|                     for child_textregion in nn: | ||||
|                         if child_textregion.tag.endswith("TextLine"): | ||||
|                             text_subelement = ET.SubElement(child_textregion, 'TextEquiv') | ||||
|                             unicode_textline = ET.SubElement(text_subelement, 'Unicode') | ||||
|                             unicode_textline.text = extracted_texts[indexer] | ||||
|                             indexer = indexer + 1 | ||||
|                             has_textline = True | ||||
|                     if has_textline: | ||||
|                         unicode_textregion.text = text_by_textregion[indexer_textregion] | ||||
|                         indexer_textregion = indexer_textregion + 1 | ||||
| 
 | ||||
|                 ET.register_namespace("",name_space) | ||||
|                 tree1.write(out_file_ocr,xml_declaration=True,method='xml',encoding="utf8",default_namespace=None) | ||||
|                 #print("Job done in %.1fs", time.time() - t0) | ||||
|  |  | |||