training/inference.py: add typing info, organize imports

commit af74890b2e
parent 3a73ccca2e
Author: kba
Date:   2025-10-17 14:07:43 +02:00


@@ -1,14 +1,15 @@
 import sys
 import os
+from typing import Tuple
 import warnings
 import json
 import numpy as np
 import cv2
-from tensorflow.keras.models import load_model
+from numpy._typing import NDArray
 import tensorflow as tf
-from tensorflow.keras import backend as K
-from tensorflow.keras.layers import *
+from keras.models import Model, load_model
+from keras import backend as K
 import click
 from tensorflow.python.keras import backend as tensorflow_backend
 import xml.etree.ElementTree as ET
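
The import hunk swaps the star import and the tensorflow.keras aliases for explicit keras names, and pulls in Tuple and NDArray for the annotations added below. A minimal sketch of how these names are typically used (split_channels is hypothetical, not from this file; note that numpy.typing is the public home of NDArray, while the commit imports it from the private numpy._typing module, which also works):

    from typing import Tuple

    import numpy as np
    from numpy.typing import NDArray

    def split_channels(img: NDArray) -> Tuple[NDArray, NDArray, NDArray]:
        # Return the three colour channels of an HxWx3 image separately.
        return img[:, :, 0], img[:, :, 1], img[:, :, 2]

    rgb = np.zeros((4, 4, 3), dtype=np.uint8)
    r, g, b = split_channels(rgb)
    print(r.shape, g.shape, b.shape)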
@@ -34,6 +35,7 @@ Tool to load model and predict for given image.
 """
 
 class sbb_predict:
     def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
         self.image=image
         self.dir_in=dir_in
@@ -77,7 +79,7 @@ class sbb_predict:
         #print(img[:,:,0].min())
         #blur = cv2.GaussianBlur(img,(5,5))
         #ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
-        retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
+        _, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
@@ -116,19 +118,19 @@ class sbb_predict:
         denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
         return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
 
-    def weighted_categorical_crossentropy(self,weights=None):
-
-        def loss(y_true, y_pred):
-            labels_floats = tf.cast(y_true, tf.float32)
-            per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
-
-            if weights is not None:
-                weight_mask = tf.maximum(tf.reduce_max(tf.constant(
-                    np.array(weights, dtype=np.float32)[None, None, None])
-                    * labels_floats, axis=-1), 1.0)
-                per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
-            return tf.reduce_mean(per_pixel_loss)
-        return self.loss
+    # def weighted_categorical_crossentropy(self,weights=None):
+    #
+    #     def loss(y_true, y_pred):
+    #         labels_floats = tf.cast(y_true, tf.float32)
+    #         per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
+    #
+    #         if weights is not None:
+    #             weight_mask = tf.maximum(tf.reduce_max(tf.constant(
+    #                 np.array(weights, dtype=np.float32)[None, None, None])
+    #                 * labels_floats, axis=-1), 1.0)
+    #             per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
+    #         return tf.reduce_mean(per_pixel_loss)
+    #     return self.loss
 
     def IoU(self,Yi,y_predi):
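
For reference, the loss being retired here computed per-pixel sigmoid cross-entropy with an optional per-class weight mask; the factory also apparently returned self.loss rather than the inner loss closure, which would have raised AttributeError if it were ever called. A standalone sketch of the intended computation (function name hypothetical):

    import numpy as np
    import tensorflow as tf

    def weighted_sigmoid_xent(y_true, y_pred, weights=None):
        # Per-pixel sigmoid cross-entropy over one-hot labels.
        labels = tf.cast(y_true, tf.float32)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=y_pred)
        if weights is not None:
            # Scale each pixel by the weight of its ground-truth class,
            # clipped below at 1.0 so unweighted pixels keep weight 1.
            w = tf.constant(np.array(weights, dtype=np.float32)[None, None, None])
            mask = tf.maximum(tf.reduce_max(w * labels, axis=-1), 1.0)
            loss = loss * mask[:, :, :, None]
        return tf.reduce_mean(loss)

    y_true = tf.constant(np.random.randint(0, 2, (1, 4, 4, 2)), tf.float32)
    y_pred = tf.constant(np.random.randn(1, 4, 4, 2), tf.float32)
    print(float(weighted_sigmoid_xent(y_true, y_pred, weights=[1.0, 5.0])))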
@@ -177,12 +179,13 @@ class sbb_predict:
         ##if self.weights_dir!=None:
             ##self.model.load_weights(self.weights_dir)
+        assert isinstance(self.model, Model)
 
         if self.task != 'classification' and self.task != 'reading_order':
             self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
             self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
             self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
 
-    def visualize_model_output(self, prediction, img, task):
+    def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
         if task == "binarization":
             prediction = prediction * -1
             prediction = prediction + 1
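
The new assert isinstance(self.model, Model) serves the type checker as much as the runtime: load_model is typed loosely, so the assert documents and narrows the type before .layers is accessed, and the same pattern recurs in predict() below. A generic sketch of the narrowing (find_value is a hypothetical stand-in):

    from typing import Optional

    def find_value(key: str) -> Optional[str]:
        # Stand-in for an API typed as possibly returning None.
        return key.upper()

    value = find_value("page")
    assert isinstance(value, str)  # narrows Optional[str] to str
    print(value.lower())           # safe: value cannot be None here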
@@ -226,9 +229,12 @@ class sbb_predict:
         added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
 
+        assert isinstance(added_image, np.ndarray)
+        assert isinstance(layout_only, np.ndarray)
         return added_image, layout_only
 
     def predict(self, image_dir):
+        assert isinstance(self.model, Model)
         if self.task == 'classification':
             classes_names = self.config_params_model['classification_classes_name']
             img_1ch = img=cv2.imread(image_dir, 0)
@@ -240,7 +246,7 @@ class sbb_predict:
             img_in[0, :, :, 1] = img_1ch[:, :]
             img_in[0, :, :, 2] = img_1ch[:, :]
 
-            label_p_pred = self.model.predict(img_in, verbose=0)
+            label_p_pred = self.model.predict(img_in, verbose='0')
             index_class = np.argmax(label_p_pred[0])
 
             print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
@@ -361,7 +367,7 @@ class sbb_predict:
                     #input_1[:,:,1] = img3[:,:,0]/5.
 
                     if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
-                        y_pr = self.model.predict(input_1 , verbose=0)
+                        y_pr = self.model.predict(input_1 , verbose='0')
                         scalibility_num = scalibility_num+1
 
                         if batch_counter==inference_bs:
@@ -395,6 +401,7 @@ class sbb_predict:
             name_space = name_space.split('{')[1]
 
             page_element = root_xml.find(link+'Page')
+            assert isinstance(page_element, ET.Element)
 
             """
             ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
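
Same narrowing pattern for XML: Element.find is typed as returning Optional[Element], so the added assert rules out None before page_element is used as a parent node. A runnable sketch (simplified: the real file uses a namespaced find via link+'Page'):

    import xml.etree.ElementTree as ET

    root = ET.fromstring("<PcGts><Page/></PcGts>")
    page = root.find("Page")
    assert isinstance(page, ET.Element)       # fails fast if Page is missing
    ro = ET.SubElement(page, "ReadingOrder")  # type checker now accepts this
    print(ET.tostring(root).decode())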
@@ -489,7 +496,7 @@ class sbb_predict:
                         img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
 
                         label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
-                                                          verbose=0)
+                                                          verbose='0')
 
                         if self.task == 'enhancement':
                             seg = label_p_pred[0, :, :, :]
@@ -497,6 +504,8 @@ class sbb_predict:
                         elif self.task == 'segmentation' or self.task == 'binarization':
                             seg = np.argmax(label_p_pred, axis=3)[0]
                             seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+                        else:
+                            raise ValueError(f"Unhandled task {self.task}")
 
                         if i == 0 and j == 0:
@@ -551,6 +560,8 @@ class sbb_predict:
             elif self.task == 'segmentation' or self.task == 'binarization':
                 seg = np.argmax(label_p_pred, axis=3)[0]
                 seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+            else:
+                raise ValueError(f"Unhandled task {self.task}")
 
             prediction_true = seg.astype(int)
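
The two added else branches make the task dispatch exhaustive: previously an unexpected self.task would have left seg unbound (which a type checker flags as possibly undefined) and crashed later with a less helpful error. The shape of the pattern, with postprocess as a hypothetical stand-in for the dispatch inside predict():

    import numpy as np

    def postprocess(task: str, label_p_pred: np.ndarray) -> np.ndarray:
        # Every branch either returns or raises; no silent fall-through.
        if task == 'enhancement':
            return label_p_pred[0, :, :, :]
        elif task == 'segmentation' or task == 'binarization':
            seg = np.argmax(label_p_pred, axis=3)[0]
            return np.repeat(seg[:, :, np.newaxis], 3, axis=2)
        else:
            raise ValueError(f"Unhandled task {task}")

    pred = np.random.rand(1, 8, 8, 3)
    print(postprocess('segmentation', pred).shape)  # (8, 8, 3)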