From f93c6c288d9525202957da5bb000202a657e6df8 Mon Sep 17 00:00:00 2001 From: vahidrezanezhad Date: Sat, 14 Dec 2024 02:50:17 +0100 Subject: [PATCH] function of patch-wise inference with scatter_nd is added --- src/eynollah/eynollah.py | 107 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 443b5e9..28cb330 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -1047,6 +1047,110 @@ class Eynollah: #label_scaled_padded[h_start:h_start+h_n, w_start:w_start+w_n,:] = label_res[:,:,:] return img_scaled_padded#, label_scaled_padded + def do_prediction_new_concept_scatter_nd(self, patches, img, model, n_batch_inference=1, marginal_of_patch_percent=0.1, thresholding_for_some_classes_in_light_version=False, thresholding_for_artificial_class_in_light_version=False): + self.logger.debug("enter do_prediction_new_concept") + + img_height_model = model.layers[-1].output_shape[1] + img_width_model = model.layers[-1].output_shape[2] + + if not patches: + img_h_page = img.shape[0] + img_w_page = img.shape[1] + img = img / 255.0 + img = resize_image(img, img_height_model, img_width_model) + + label_p_pred = model.predict(img[np.newaxis], verbose=0) + seg = np.argmax(label_p_pred, axis=3)[0] + + if thresholding_for_artificial_class_in_light_version: + #seg_text = label_p_pred[0,:,:,1] + #seg_text[seg_text<0.2] =0 + #seg_text[seg_text>0] =1 + #seg[seg_text==1]=1 + + seg_art = label_p_pred[0,:,:,4] + seg_art[seg_art<0.2] =0 + seg_art[seg_art>0] =1 + seg[seg_art==1]=4 + + + seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) + prediction_true = resize_image(seg_color, img_h_page, img_w_page) + prediction_true = prediction_true.astype(np.uint8) + return prediction_true + + if img.shape[0] < img_height_model: + img = resize_image(img, img_height_model, img.shape[1]) + + if img.shape[1] < img_width_model: + img = resize_image(img, img.shape[0], img_width_model) + + self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) + ##margin = int(marginal_of_patch_percent * img_height_model) + #width_mid = img_width_model - 2 * margin + #height_mid = img_height_model - 2 * margin + img = img / float(255.0) + + img = img.astype(np.float16) + img_h = img.shape[0] + img_w = img.shape[1] + + stride_x = img_width_model - 100 + stride_y = img_height_model - 100 + + one_tensor = tf.ones_like(img) + img_patches = tf.image.extract_patches(images=[img,one_tensor], + sizes=[1, img_height_model, img_width_model, 1], + strides=[1, stride_y, stride_x, 1], + rates=[1, 1, 1, 1], + padding='SAME') + + one_patches = img_patches[1] + img_patches = img_patches[0] + img_patches = tf.squeeze(img_patches) + + img_patches_resh = tf.reshape(img_patches, shape = (img_patches.shape[0]*img_patches.shape[1], img_height_model, img_width_model, 3)) + + pred_patches = model.predict(img_patches_resh, batch_size=n_batch_inference) + + one_patches = tf.squeeze(one_patches) + one_patches = tf.reshape(one_patches, [img_patches.shape[0]*img_patches.shape[1],img_height_model,img_width_model,3]) + + x = tf.range(img.shape[1]) + y = tf.range(img.shape[0]) + x, y = tf.meshgrid(x, y) + indices = tf.stack([y, x], axis=-1) + + indices_patches = tf.image.extract_patches(images=tf.expand_dims(indices, axis=0), sizes=[1, img_height_model, img_width_model, 1], strides=[1, stride_y, stride_x, 1], rates=[1, 1, 1, 1], padding='SAME') + indices_patches = tf.squeeze(indices_patches) + indices_patches = tf.reshape(indices_patches, [img_patches.shape[0]*img_patches.shape[1],img_height_model, img_width_model,2]) + + margin_y = int( (img_height_model - stride_y)/2. ) + margin_x = int( (img_width_model - stride_x)/2. ) + + mask_margin = np.zeros((img_height_model, img_width_model)) + + mask_margin[margin_y:img_height_model-margin_y, margin_x:img_width_model-margin_x] = 1 + + indices_patches_array = indices_patches.numpy() + + for i in range(indices_patches_array.shape[0]): + indices_patches_array[i,:,:,0] = indices_patches_array[i,:,:,0]*mask_margin + indices_patches_array[i,:,:,1] = indices_patches_array[i,:,:,1]*mask_margin + + reconstructed = tf.scatter_nd(indices=indices_patches_array, updates=pred_patches, shape=(img.shape[0],img.shape[1],pred_patches.shape[-1])) + reconstructed_argmax = reconstructed.numpy() + + prediction_true = np.argmax(reconstructed_argmax, axis=2) + prediction_true = prediction_true.astype(np.uint8) + + gc.collect() + return np.repeat(prediction_true[:, :, np.newaxis], 3, axis=2) + + + + + def do_prediction_new_concept(self, patches, img, model, n_batch_inference=1, marginal_of_patch_percent=0.1, thresholding_for_some_classes_in_light_version=False, thresholding_for_artificial_class_in_light_version=False): self.logger.debug("enter do_prediction_new_concept") @@ -4891,7 +4995,7 @@ class Eynollah: all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_h, slopes_marginals, cont_page, polygons_lines_xml, ocr_all_textlines) self.logger.info("Job done in %.1fs", time.time() - t0) - print("Job done in %.1fs", time.time() - t0) + #print("Job done in %.1fs", time.time() - t0) if self.dir_in: self.writer.write_pagexml(pcgts) continue @@ -4975,6 +5079,7 @@ class Eynollah: pcgts = self.writer.build_pagexml_no_full_layout(txt_con_org, page_coord, order_text_new, id_of_texts_tot, all_found_textline_polygons, all_box_coord, polygons_of_images, polygons_of_marginals, all_found_textline_polygons_marginals, all_box_coord_marginals, slopes, slopes_marginals, cont_page, polygons_lines_xml, contours_tables, ocr_all_textlines) + #print("Job done in %.1fs" % (time.time() - t0)) self.logger.info("Job done in %.1fs", time.time() - t0) if not self.dir_in: return pcgts