import sys
import os
import numpy as np
import warnings
import cv2
import seaborn as sns
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers
import tensorflow.keras.losses
from tensorflow.keras.layers import *
from models import *
from gt_gen_utils import *
import click
import json
from tensorflow.python.keras import backend as tensorflow_backend
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

__doc__ = \
"""
Tool to load model and predict for given image.
"""


class sbb_predict:
    """Load a trained Keras model and run inference for a single input.

    Behaviour depends on ``task`` (taken from the model's ``config.json``):

    * ``classification`` - print the predicted class of ``image``.
    * ``reading_order``  - order the text regions of the XML file ``xml_file`` and
      write the updated XML into the directory ``out``.
    * any other task (``segmentation``, ``binarization``, ``enhancement``) - predict a
      per-pixel output for ``image``; ``run`` visualises it and writes it to ``save``.
    """

    def __init__(self, image, model, task, config_params_model, patches, save, ground_truth, xml_file, out):
        self.image = image                              # path of the input image
        self.patches = patches                          # predict tile-by-tile instead of resizing the whole image
        self.save = save                                # output path of the visualised prediction
        self.model_dir = model                          # path handed to ``load_model``
        self.ground_truth = ground_truth                # optional ground-truth image path used for IoU
        self.task = task
        self.config_params_model = config_params_model  # parsed ``config.json`` of the model
        self.xml_file = xml_file                        # input XML for the reading_order task
        self.out = out                                  # output directory for the reading_order task

    def resize_image(self, img_in, input_height, input_width):
        """Resize ``img_in`` to ``(input_height, input_width)`` using nearest-neighbour interpolation.

        ``cv2.resize`` expects its target size as ``(width, height)``, hence the swap.
        """
        return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)

    def color_images(self, seg):
        """Return a 3-channel uint8 image whose pixels hold the class index found in ``seg``.

        Only indices ``0 .. self.n_classes - 1`` are copied, other values stay 0.
        If ``seg`` has three dimensions only its first channel is used.
        """
        if len(np.shape(seg)) == 3:
            seg = seg[:, :, 0]

        seg_img = np.zeros((np.shape(seg)[0], np.shape(seg)[1], 3)).astype(np.uint8)
        for c in range(self.n_classes):
            mask = seg == c
            seg_img[:, :, 0][mask] = c
            seg_img[:, :, 1][mask] = c
            seg_img[:, :, 2][mask] = c
        return seg_img

    def otsu_copy_binary(self, img):
        """Binarise the first channel of ``img`` with Otsu's threshold and copy it into 3 channels.

        Returns a float array of shape ``(H, W, 3)`` whose values are 0 or 255.
        """
        img_r = np.zeros((img.shape[0], img.shape[1], 3))
        _, threshold1 = cv2.threshold(img[:, :, 0], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        img_r[:, :, 0] = threshold1
        img_r[:, :, 1] = threshold1
        img_r[:, :, 2] = threshold1
        return img_r

    def otsu_copy(self, img):
        """Binarise each of the first three channels of ``img`` independently with Otsu's threshold."""
        img_r = np.zeros((img.shape[0], img.shape[1], 3))
        for channel in range(3):
            _, threshold = cv2.threshold(img[:, :, channel], 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            img_r[:, :, channel] = threshold
        return img_r

    def soft_dice_loss(self, y_true, y_pred, epsilon=1e-6):
        """Soft Dice loss averaged over classes and batch.

        Sums run over every axis except the first and the last one; ``epsilon``
        guards against division by zero.
        """
        axes = tuple(range(1, len(y_pred.shape) - 1))
        numerator = 2. * K.sum(y_pred * y_true, axes)
        denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
        return 1.00 - K.mean(numerator / (denominator + epsilon))  # average over classes and batch

    def weighted_categorical_crossentropy(self, weights=None):
        """Build a Keras loss ``loss(y_true, y_pred)`` based on sigmoid cross-entropy.

        ``y_pred`` is treated as logits. When ``weights`` is given, each pixel's loss
        is scaled by the maximum of ``weights * y_true`` over the last axis, clipped
        below at 1.0.

        NOTE(review): despite the name, this is sigmoid (not softmax) cross-entropy,
        and the ``[:, :, :, None]`` indexing assumes 4-D tensors.
        """
        def loss(y_true, y_pred):
            labels_floats = tf.cast(y_true, tf.float32)
            per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats, logits=y_pred)

            if weights is not None:
                weight_mask = tf.maximum(
                    tf.reduce_max(
                        tf.constant(np.array(weights, dtype=np.float32)[None, None, None]) * labels_floats,
                        axis=-1),
                    1.0)
                per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
            return tf.reduce_mean(per_pixel_loss)

        # The closure itself is the loss; ``self.loss`` does not exist.
        return loss

    def IoU(self, Yi, y_predi):
        """Print and return the intersection-over-union of ``y_predi`` against ``Yi``.

        Per class: IoU = TP / (TP + FP + FN), computed for every class present in ``Yi``.
        With ``self.n_classes > 2`` the mean over those classes is returned; with
        ``self.n_classes == 2`` the IoU of the second (sorted) class in ``Yi`` is
        returned; otherwise ``None``.
        """
        ious = []
        for c in np.unique(Yi):
            TP = np.sum((Yi == c) & (y_predi == c))
            FP = np.sum((Yi != c) & (y_predi == c))
            FN = np.sum((Yi == c) & (y_predi != c))
            iou = TP / float(TP + FP + FN)
            if self.n_classes > 2:
                print("class {:02.0f}: #TP={:6.0f}, #FP={:6.0f}, #FN={:5.0f}, IoU={:4.3f}".format(c, TP, FP, FN, iou))
            ious.append(iou)

        if self.n_classes > 2:
            mIoU = np.mean(ious)
            print("_________________")
            print("Mean IoU: {:4.3f}".format(mIoU))
            return mIoU
        elif self.n_classes == 2:
            # NOTE(review): raises IndexError when the ground truth contains only one class.
            mIoU = ious[1]
            print("_________________")
            print("IoU: {:4.3f}".format(mIoU))
            return mIoU

    def start_new_session_and_model(self):
        """Create a TF session with on-demand GPU memory growth and load the model.

        For pixel-wise tasks (anything but ``classification``/``reading_order``) the
        height, width and channel count of the last layer's output are stored in
        ``self.img_height``, ``self.img_width`` and ``self.n_classes``; the first two
        are used as the size images are resized/tiled to.
        """
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        session = tf.compat.v1.Session(config=config)
        tensorflow_backend.set_session(session)

        # Custom layers (presumably defined in ``models``) must be registered for load_model.
        self.model = load_model(self.model_dir, compile=False,
                                custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches})

        if self.task != 'classification' and self.task != 'reading_order':
            output_shape = self.model.layers[-1].output_shape
            self.img_height = output_shape[1]
            self.img_width = output_shape[2]
            self.n_classes = output_shape[3]

    def visualize_model_output(self, prediction, img, task):
        """Turn a per-pixel prediction into an image that can be written with ``cv2.imwrite``.

        * ``binarization``: class 1 becomes 0 and class 0 becomes 255; ``img`` is unused.
        * otherwise: each class index (0-15 supported) gets a fixed colour, and the
          colour map is blended onto ``img`` resized to the prediction's size.
        """
        if task == "binarization":
            prediction = prediction * -1
            prediction = prediction + 1
            added_image = prediction * 255
        else:
            unique_classes = np.unique(prediction[:, :, 0])
            rgb_colors = {'0': [255, 255, 255],
                          '1': [255, 0, 0],
                          '2': [255, 125, 0],
                          '3': [255, 0, 125],
                          '4': [125, 125, 125],
                          '5': [125, 125, 0],
                          '6': [0, 125, 255],
                          '7': [0, 125, 0],
                          '8': [125, 125, 125],
                          '9': [0, 125, 255],
                          '10': [125, 0, 125],
                          '11': [0, 255, 0],
                          '12': [0, 0, 255],
                          '13': [0, 255, 255],
                          '14': [255, 125, 125],
                          '15': [255, 0, 255]}

            output = np.zeros(prediction.shape)
            for unq_class in unique_classes:
                rgb_class_unique = rgb_colors[str(int(unq_class))]
                mask = prediction[:, :, 0] == unq_class
                output[:, :, 0][mask] = rgb_class_unique[0]
                output[:, :, 1][mask] = rgb_class_unique[1]
                output[:, :, 2][mask] = rgb_class_unique[2]

            img = self.resize_image(img, output.shape[0], output.shape[1])

            output = output.astype(np.int32)
            img = img.astype(np.int32)

            added_image = cv2.addWeighted(img, 0.5, output, 0.1, 0)

        return added_image

    def predict(self):
        """Load the model and run the prediction for ``self.task``.

        Returns the per-pixel prediction (integer array resized to the original image
        size) for pixel-wise tasks, and ``None`` for ``classification`` (which prints
        its result) and ``reading_order`` (which writes an XML file).
        """
        self.start_new_session_and_model()
        if self.task == 'classification':
            self._predict_classification()
        elif self.task == 'reading_order':
            self._predict_reading_order()
        elif self.patches:
            return self._predict_in_patches()
        else:
            return self._predict_whole_image()

    def _predict_classification(self):
        """Print the class predicted for the grayscale version of ``self.image``."""
        classes_names = self.config_params_model['classification_classes_name']
        input_height = self.config_params_model['input_height']
        input_width = self.config_params_model['input_width']

        img_1ch = cv2.imread(self.image, 0)
        img_1ch = img_1ch / 255.0
        # cv2.resize takes (width, height); passing (height, width) broke non-square models.
        img_1ch = cv2.resize(img_1ch, (input_width, input_height), interpolation=cv2.INTER_NEAREST)

        # Batch of one, grayscale replicated into the three input channels.
        img_in = np.repeat(img_1ch[np.newaxis, :, :, np.newaxis], 3, axis=3)

        label_p_pred = self.model.predict(img_in, verbose=0)
        index_class = np.argmax(label_p_pred[0])

        print("Predicted Class: {}".format(classes_names[str(int(index_class))]))

    def _predict_reading_order(self):
        """Predict the reading order of the text regions in ``self.xml_file``.

        For a pivot region ``i`` every remaining region ``j`` is scored by the model
        and put into ``post_list`` (score >= 0.5) or ``pr_list`` (presumably "after"
        and "before" ``i``); ``update_list_and_return_first_with_length_bigger_than_one``
        merges these partitions until it returns a negative index. The resulting order
        is added as a ``ReadingOrder`` element to the XML tree, which is written to
        ``<self.out>/<file_name>.xml``.
        """
        img_height = self.config_params_model['input_height']
        img_width = self.config_params_model['input_width']

        (tree_xml, root_xml, bb_coord_printspace, file_name, id_paragraph, id_header,
         co_text_paragraph, co_text_header, tot_region_ref, x_len, y_len,
         index_tot_regions, img_poly) = read_xml(self.xml_file)
        _, cy_main, x_min_main, x_max_main, y_min_main, y_max_main, _ = find_new_features_of_contours(co_text_header)

        # Mark a 12-pixel-high band starting at the bottom edge of every header region.
        img_header_and_sep = np.zeros((y_len, x_len), dtype='uint8')
        for j in range(len(cy_main)):
            img_header_and_sep[int(y_max_main[j]):int(y_max_main[j]) + 12,
                               int(x_min_main[j]):int(x_max_main[j])] = 1

        co_text_all = co_text_paragraph + co_text_header
        id_all_text = id_paragraph + id_header

        # One binary mask per text region (paragraphs first, then headers).
        labels_con = np.zeros((y_len, x_len, len(co_text_all)), dtype='uint8')
        for i in range(len(co_text_all)):
            img_label = np.zeros((y_len, x_len, 3), dtype='uint8')
            img_label = cv2.fillPoly(img_label, pts=[co_text_all[i]], color=(1, 1, 1))
            labels_con[:, :, i] = img_label[:, :, 0]

        if bb_coord_printspace:
            # Crop everything to the print space, given as [x, y, w, h, ...].
            x, y, w, h = bb_coord_printspace[:4]
            labels_con = labels_con[y:y + h, x:x + w, :]
            img_poly = img_poly[y:y + h, x:x + w, :]
            img_header_and_sep = img_header_and_sep[y:y + h, x:x + w]

        img3 = np.copy(img_poly)
        labels_con = resize_image(labels_con, img_height, img_width)
        img_header_and_sep = resize_image(img_header_and_sep, img_height, img_width)
        img3 = resize_image(img3, img_height, img_width)
        img3 = img3.astype(np.uint16)

        inference_bs = 1  # batch size used for model.predict

        input_1 = np.zeros((inference_bs, img_height, img_width, 3))

        starting_list_of_regions = [list(range(labels_con.shape[2]))]
        index_update = 0

        while index_update >= 0:
            ij_list = starting_list_of_regions[index_update]
            i = ij_list.pop(0)

            pr_list = []
            post_list = []

            batch_counter = 0
            tot_counter = 1

            tot_iteration = len(ij_list)
            full_bs_ite = tot_iteration // inference_bs
            last_bs = tot_iteration % inference_bs

            # The pivot mask depends only on ``i``: build it once per pivot.
            img1 = np.repeat(labels_con[:, :, i][:, :, np.newaxis], 3, axis=2)
            img1[:, :, 0][img3[:, :, 0] == 5] = 2
            img1[:, :, 0][img_header_and_sep[:, :] == 1] = 3

            jbatch_indexer = []
            for j in ij_list:
                img2 = np.repeat(labels_con[:, :, j][:, :, np.newaxis], 3, axis=2)
                img2[:, :, 0][img3[:, :, 0] == 5] = 2
                img2[:, :, 0][img_header_and_sep[:, :] == 1] = 3

                jbatch_indexer.append(j)

                # Channels: 0 = pivot region i, 1 = img3 (from read_xml's img_poly), 2 = candidate j.
                input_1[batch_counter, :, :, 0] = img1[:, :, 0] / 3.
                input_1[batch_counter, :, :, 2] = img2[:, :, 0] / 3.
                input_1[batch_counter, :, :, 1] = img3[:, :, 0] / 5.
                batch_counter = batch_counter + 1

                # Run the model once a batch is full, or on the final partial batch.
                if batch_counter == inference_bs or (
                        (tot_counter // inference_bs) == full_bs_ite and tot_counter % inference_bs == last_bs):
                    y_pr = self.model.predict(input_1, verbose=0)

                    iteration_batches = inference_bs if batch_counter == inference_bs else last_bs
                    for jb in range(iteration_batches):
                        if y_pr[jb][0] >= 0.5:
                            post_list.append(jbatch_indexer[jb])
                        else:
                            pr_list.append(jbatch_indexer[jb])

                    batch_counter = 0
                    jbatch_indexer = []

                tot_counter = tot_counter + 1

            starting_list_of_regions, index_update = update_list_and_return_first_with_length_bigger_than_one(
                index_update, i, pr_list, post_list, starting_list_of_regions)

        index_sort = [region_list[0] for region_list in starting_list_of_regions]
        id_all_text = np.array(id_all_text)[index_sort]

        alltags = [elem.tag for elem in root_xml.iter()]
        # The root tag is expected to be namespaced, i.e. "{namespace}Name".
        link = alltags[0].split('}')[0] + '}'
        name_space = alltags[0].split('}')[0].split('{')[1]

        page_element = root_xml.find(link + 'Page')

        ro_subelement = ET.Element('ReadingOrder')
        ro_subelement2 = ET.SubElement(ro_subelement, 'OrderedGroup')
        ro_subelement2.set('id', "ro357564684568544579089")

        for index, id_text in enumerate(id_all_text):
            new_element_2 = ET.SubElement(ro_subelement2, 'RegionRefIndexed')
            new_element_2.set('regionRef', id_text)
            new_element_2.set('index', str(index))

        # With a PrintSpace/Border present, insert at position 1 (presumably right after it).
        if (link + 'PrintSpace' in alltags) or (link + 'Border' in alltags):
            page_element.insert(1, ro_subelement)
        else:
            page_element.insert(0, ro_subelement)

        ET.register_namespace("", name_space)
        tree_xml.write(os.path.join(self.out, file_name + '.xml'), xml_declaration=True, method='xml',
                       encoding="utf8", default_namespace=None)

    def _prediction_to_seg(self, label_p_pred):
        """Convert a batch-of-one model output into an ``(H, W, C)`` array.

        ``enhancement`` outputs are multiplied by 255; ``segmentation`` and
        ``binarization`` outputs are reduced to the argmax class index, repeated
        over 3 channels.

        Raises:
            ValueError: for any other task.
        """
        if self.task == 'enhancement':
            return label_p_pred[0, :, :, :] * 255
        if self.task == 'segmentation' or self.task == 'binarization':
            seg = np.argmax(label_p_pred, axis=3)[0]
            return np.repeat(seg[:, :, np.newaxis], 3, axis=2)
        raise ValueError("Unsupported task for pixel-wise prediction: {}".format(self.task))

    def _predict_in_patches(self):
        """Predict ``self.image`` tile by tile at the model's native resolution.

        Images smaller than the model input are upscaled first so at least one tile
        fits. Tiles are laid out on a regular grid; tiles that would cross the image
        border are shifted back inside it. The stitched result is resized to the
        original image size.
        """
        img = cv2.imread(self.image)
        self.img_org = np.copy(img)

        # cv2.resize takes (width, height); the target sizes were previously swapped.
        if img.shape[0] < self.img_height:
            img = cv2.resize(img, (img.shape[1], self.img_height), interpolation=cv2.INTER_NEAREST)
        if img.shape[1] < self.img_width:
            img = cv2.resize(img, (self.img_width, img.shape[0]), interpolation=cv2.INTER_NEAREST)

        # Pixels discarded at tile borders that are not image borders (currently none).
        margin = int(0 * self.img_width)
        width_mid = self.img_width - 2 * margin
        height_mid = self.img_height - 2 * margin

        img = img / float(255.0)
        img_h = img.shape[0]
        img_w = img.shape[1]

        prediction_true = np.zeros((img_h, img_w, 3))
        # Number of tiles per axis (integer ceiling division).
        nxf = (img_w + width_mid - 1) // width_mid
        nyf = (img_h + height_mid - 1) // height_mid

        for i in range(nxf):
            for j in range(nyf):
                index_x_d = i * width_mid
                index_x_u = index_x_d + self.img_width
                index_y_d = j * height_mid
                index_y_u = index_y_d + self.img_height

                # Shift the last tile back so it ends exactly at the image border.
                if index_x_u > img_w:
                    index_x_u = img_w
                    index_x_d = img_w - self.img_width
                if index_y_u > img_h:
                    index_y_u = img_h
                    index_y_d = img_h - self.img_height

                img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
                label_p_pred = self.model.predict(img_patch[np.newaxis], verbose=0)
                seg = self._prediction_to_seg(label_p_pred)

                # Drop ``margin`` pixels on each side that is not an image border.
                top = 0 if j == 0 else margin
                bottom = 0 if j == nyf - 1 else margin
                left = 0 if i == 0 else margin
                right = 0 if i == nxf - 1 else margin

                seg = seg[top:seg.shape[0] - bottom, left:seg.shape[1] - right]
                prediction_true[index_y_d + top:index_y_u - bottom,
                                index_x_d + left:index_x_u - right, :] = seg

        prediction_true = prediction_true.astype(int)
        return cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]),
                          interpolation=cv2.INTER_NEAREST)

    def _predict_whole_image(self):
        """Predict ``self.image`` resized to the model input size, then resize the result back."""
        img = cv2.imread(self.image)
        self.img_org = np.copy(img)

        img = img / 255.0
        img = self.resize_image(img, self.img_height, self.img_width)

        label_p_pred = self.model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
        prediction_true = self._prediction_to_seg(label_p_pred).astype(int)

        return cv2.resize(prediction_true, (self.img_org.shape[1], self.img_org.shape[0]),
                          interpolation=cv2.INTER_NEAREST)

    def run(self):
        """Run the prediction; for pixel-wise tasks also save the visualisation and print the IoU."""
        res = self.predict()
        if self.task == 'classification' or self.task == 'reading_order':
            return

        img_seg_overlayed = self.visualize_model_output(res, self.img_org, self.task)
        if self.save:
            cv2.imwrite(self.save, img_seg_overlayed)
        if self.ground_truth:
            gt_img = cv2.imread(self.ground_truth)
            self.IoU(gt_img[:, :, 0], res[:, :, 0])


@click.command()
@click.option(
    "--image",
    "-i",
    help="image filename",
    type=click.Path(exists=True, dir_okay=False),
)
@click.option(
    "--out",
    "-o",
    help="output directory where xml with detected reading order will be written.",
    type=click.Path(exists=True, file_okay=False),
)
@click.option(
    "--patches/--no-patches",
    "-p/-nop",
    is_flag=True,
    help="if this parameter set to true, this tool will try to do inference in patches.",
)
@click.option(
    "--save",
    "-s",
    help="save prediction as a png file in current folder.",
)
@click.option(
    "--model",
    "-m",
    help="directory of models",
    type=click.Path(exists=True, file_okay=False),
    required=True,
)
@click.option(
    "--ground_truth",
    "-gt",
    help="ground truth image file if you want to see the iou of prediction.",
)
@click.option(
    "--xml_file",
    "-xml",
    help="xml file with layout coordinates that reading order detection will be implemented on. The result is written to the output directory given with -o.",
)
def main(image, model, patches, save, ground_truth, xml_file, out):
    """Command-line entry point: read the model's config.json, validate options and run the prediction."""
    with open(os.path.join(model, 'config.json'), encoding='utf-8') as f:
        config_params_model = json.load(f)
    task = config_params_model['task']

    if task == 'reading_order':
        if not xml_file or not out:
            print("Error: the reading_order task needs both -xml (input xml file) and -o (output directory)")
            sys.exit(1)
    else:
        if not image:
            print("Error: the {} task needs an input image, set it with -i".format(task))
            sys.exit(1)
        if task != 'classification' and not save:
            print("Error: You used one of segmentation or binarization task but not set -s, you need a filename to save visualized output with -s")
            sys.exit(1)

    x = sbb_predict(image, model, task, config_params_model, patches, save, ground_truth, xml_file, out)
    x.run()


if __name__ == "__main__":
    main()