diff --git a/src/eynollah/training/inference.py b/src/eynollah/training/inference.py
index 3fa8fd6..10fca6c 100644
--- a/src/eynollah/training/inference.py
+++ b/src/eynollah/training/inference.py
@@ -1,14 +1,15 @@
 import sys
 import os
+from typing import Tuple
 import warnings
 import json
 import numpy as np
 import cv2
-from tensorflow.keras.models import load_model
+from numpy.typing import NDArray
 import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import *
+from keras.models import Model, load_model
+from keras import backend as K
 import click
 from tensorflow.python.keras import backend as tensorflow_backend
 import xml.etree.ElementTree as ET
 
@@ -34,6 +35,7 @@ Tool to load model and predict for given image.
 
 """
 class sbb_predict:
+
     def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
         self.image=image
         self.dir_in=dir_in
@@ -77,7 +79,7 @@
         #print(img[:,:,0].min())
         #blur = cv2.GaussianBlur(img,(5,5))
         #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
-        retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+        _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
 
 
 
@@ -116,19 +118,19 @@
             denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
         return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
 
-    def weighted_categorical_crossentropy(self,weights=None):
-
-        def loss(y_true, y_pred):
-            labels_floats = tf.cast(y_true, tf.float32)
-            per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
-
-            if weights is not None:
-                weight_mask = tf.maximum(tf.reduce_max(tf.constant(
-                    np.array(weights, dtype=np.float32)[None, None, None])
-                    * labels_floats, axis=-1), 1.0)
-                per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
-            return tf.reduce_mean(per_pixel_loss)
-        return self.loss
+    # def weighted_categorical_crossentropy(self,weights=None):
+    #
+    #     def loss(y_true, y_pred):
+    #         labels_floats = tf.cast(y_true, tf.float32)
+    #         per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
+    #
+    #         if weights is not None:
+    #             weight_mask = tf.maximum(tf.reduce_max(tf.constant(
+    #                 np.array(weights, dtype=np.float32)[None, None, None])
+    #                 * labels_floats, axis=-1), 1.0)
+    #             per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
+    #         return tf.reduce_mean(per_pixel_loss)
+    #     return self.loss
 
 
     def IoU(self,Yi,y_predi):
@@ -177,12 +179,13 @@
 
         ##if self.weights_dir!=None:
             ##self.model.load_weights(self.weights_dir)
+        assert isinstance(self.model, Model)
         if self.task != 'classification' and self.task != 'reading_order':
             self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
             self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
             self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
 
-    def visualize_model_output(self, prediction, img, task):
+    def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
         if task == "binarization":
             prediction = prediction * -1
             prediction = prediction + 1
@@ -226,9 +229,12 @@
 
         added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
 
+        assert isinstance(added_image, np.ndarray)
+        assert isinstance(layout_only, np.ndarray)
         return added_image, layout_only
 
     def predict(self, image_dir):
+        assert isinstance(self.model, Model)
         if self.task == 'classification':
             classes_names = self.config_params_model['classification_classes_name']
             img_1ch = img=cv2.imread(image_dir, 0)
@@ -395,6 +401,7 @@
             name_space = name_space.split('{')[1]
 
         page_element = root_xml.find(link+'Page')
+        assert isinstance(page_element, ET.Element)
 
         """
         ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
@@ -497,6 +504,8 @@
                     elif self.task == 'segmentation' or self.task == 'binarization':
                         seg = np.argmax(label_p_pred, axis=3)[0]
                         seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+                    else:
+                        raise ValueError(f"Unhandled task {self.task}")
 
 
                     if i == 0 and j == 0:
@@ -551,6 +560,8 @@
                 elif self.task == 'segmentation' or self.task == 'binarization':
                     seg = np.argmax(label_p_pred, axis=3)[0]
                     seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+                else:
+                    raise ValueError(f"Unhandled task {self.task}")
 
                 prediction_true = seg.astype(int)