mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-10-26 23:34:13 +01:00
training/inference.py: add typing info, organize imports
This commit is contained in:
parent
3a73ccca2e
commit
af74890b2e
1 changed files with 32 additions and 21 deletions
|
|
@ -1,14 +1,15 @@
|
|||
import sys
|
||||
import os
|
||||
from typing import Tuple
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tensorflow.keras.models import load_model
|
||||
from numpy._typing import NDArray
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import backend as K
|
||||
from tensorflow.keras.layers import *
|
||||
from keras.models import Model, load_model
|
||||
from keras import backend as K
|
||||
import click
|
||||
from tensorflow.python.keras import backend as tensorflow_backend
|
||||
import xml.etree.ElementTree as ET
|
||||
|
|
@ -34,6 +35,7 @@ Tool to load model and predict for given image.
|
|||
"""
|
||||
|
||||
class sbb_predict:
|
||||
|
||||
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
|
||||
self.image=image
|
||||
self.dir_in=dir_in
|
||||
|
|
@ -77,7 +79,7 @@ class sbb_predict:
|
|||
#print(img[:,:,0].min())
|
||||
#blur = cv2.GaussianBlur(img,(5,5))
|
||||
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
_, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
|
||||
|
||||
|
||||
|
||||
|
|
@ -116,19 +118,19 @@ class sbb_predict:
|
|||
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
|
||||
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
|
||||
|
||||
def weighted_categorical_crossentropy(self,weights=None):
|
||||
|
||||
def loss(y_true, y_pred):
|
||||
labels_floats = tf.cast(y_true, tf.float32)
|
||||
per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
|
||||
|
||||
if weights is not None:
|
||||
weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||
np.array(weights, dtype=np.float32)[None, None, None])
|
||||
* labels_floats, axis=-1), 1.0)
|
||||
per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||
return tf.reduce_mean(per_pixel_loss)
|
||||
return self.loss
|
||||
# def weighted_categorical_crossentropy(self,weights=None):
|
||||
#
|
||||
# def loss(y_true, y_pred):
|
||||
# labels_floats = tf.cast(y_true, tf.float32)
|
||||
# per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
|
||||
#
|
||||
# if weights is not None:
|
||||
# weight_mask = tf.maximum(tf.reduce_max(tf.constant(
|
||||
# np.array(weights, dtype=np.float32)[None, None, None])
|
||||
# * labels_floats, axis=-1), 1.0)
|
||||
# per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
|
||||
# return tf.reduce_mean(per_pixel_loss)
|
||||
# return self.loss
|
||||
|
||||
|
||||
def IoU(self,Yi,y_predi):
|
||||
|
|
@ -177,12 +179,13 @@ class sbb_predict:
|
|||
##if self.weights_dir!=None:
|
||||
##self.model.load_weights(self.weights_dir)
|
||||
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task != 'classification' and self.task != 'reading_order':
|
||||
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
|
||||
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
|
||||
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
|
||||
|
||||
def visualize_model_output(self, prediction, img, task):
|
||||
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
|
||||
if task == "binarization":
|
||||
prediction = prediction * -1
|
||||
prediction = prediction + 1
|
||||
|
|
@ -226,9 +229,12 @@ class sbb_predict:
|
|||
|
||||
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
|
||||
|
||||
assert isinstance(added_image, np.ndarray)
|
||||
assert isinstance(layout_only, np.ndarray)
|
||||
return added_image, layout_only
|
||||
|
||||
def predict(self, image_dir):
|
||||
assert isinstance(self.model, Model)
|
||||
if self.task == 'classification':
|
||||
classes_names = self.config_params_model['classification_classes_name']
|
||||
img_1ch = img=cv2.imread(image_dir, 0)
|
||||
|
|
@ -240,7 +246,7 @@ class sbb_predict:
|
|||
img_in[0, :, :, 1] = img_1ch[:, :]
|
||||
img_in[0, :, :, 2] = img_1ch[:, :]
|
||||
|
||||
label_p_pred = self.model.predict(img_in, verbose=0)
|
||||
label_p_pred = self.model.predict(img_in, verbose='0')
|
||||
index_class = np.argmax(label_p_pred[0])
|
||||
|
||||
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
|
||||
|
|
@ -361,7 +367,7 @@ class sbb_predict:
|
|||
#input_1[:,:,1] = img3[:,:,0]/5.
|
||||
|
||||
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
|
||||
y_pr = self.model.predict(input_1 , verbose=0)
|
||||
y_pr = self.model.predict(input_1 , verbose='0')
|
||||
scalibility_num = scalibility_num+1
|
||||
|
||||
if batch_counter==inference_bs:
|
||||
|
|
@ -395,6 +401,7 @@ class sbb_predict:
|
|||
name_space = name_space.split('{')[1]
|
||||
|
||||
page_element = root_xml.find(link+'Page')
|
||||
assert isinstance(page_element, ET.Element)
|
||||
|
||||
"""
|
||||
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
|
||||
|
|
@ -489,7 +496,7 @@ class sbb_predict:
|
|||
|
||||
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
|
||||
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
|
||||
verbose=0)
|
||||
verbose='0')
|
||||
|
||||
if self.task == 'enhancement':
|
||||
seg = label_p_pred[0, :, :, :]
|
||||
|
|
@ -497,6 +504,8 @@ class sbb_predict:
|
|||
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||
else:
|
||||
raise ValueError(f"Unhandled task {self.task}")
|
||||
|
||||
|
||||
if i == 0 and j == 0:
|
||||
|
|
@ -551,6 +560,8 @@ class sbb_predict:
|
|||
elif self.task == 'segmentation' or self.task == 'binarization':
|
||||
seg = np.argmax(label_p_pred, axis=3)[0]
|
||||
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
|
||||
else:
|
||||
raise ValueError(f"Unhandled task {self.task}")
|
||||
|
||||
prediction_true = seg.astype(int)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue