training/inference.py: add typing info, organize imports

This commit is contained in:
kba 2025-10-17 14:07:43 +02:00
parent 3a73ccca2e
commit af74890b2e

View file

@ -1,14 +1,15 @@
import sys
import os
from typing import Tuple
import warnings
import json
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from numpy._typing import NDArray
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from keras.models import Model, load_model
from keras import backend as K
import click
from tensorflow.python.keras import backend as tensorflow_backend
import xml.etree.ElementTree as ET
@ -34,6 +35,7 @@ Tool to load model and predict for given image.
"""
class sbb_predict:
def __init__(self,image, dir_in, model, task, config_params_model, patches, save, save_layout, ground_truth, xml_file, out, min_area):
self.image=image
self.dir_in=dir_in
@ -77,7 +79,7 @@ class sbb_predict:
#print(img[:,:,0].min())
#blur = cv2.GaussianBlur(img,(5,5))
#ret3,th3 = cv2.threshold(blur,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
retval1, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
_, threshold1 = cv2.threshold(img1, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
@ -116,19 +118,19 @@ class sbb_predict:
denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
return 1.00 - K.mean(numerator / (denominator + epsilon)) # average over classes and batch
def weighted_categorical_crossentropy(self, weights=None):
    """Build a (optionally class-weighted) per-pixel sigmoid cross-entropy loss.

    Args:
        weights: optional sequence of per-class weights. When given, each
            position's loss is scaled by the largest weight among the
            classes active in ``y_true`` at that position, floored at 1.0.
            When ``None`` the unweighted mean loss is used.

    Returns:
        A ``loss(y_true, y_pred)`` callable that returns the mean of the
        per-pixel losses as a scalar tensor.
    """

    def loss(y_true, y_pred):
        labels_floats = tf.cast(y_true, tf.float32)
        per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels_floats, logits=y_pred)

        if weights is not None:
            # Three leading singleton axes let the class weights broadcast
            # against the labels' last axis.
            # NOTE(review): presumably labels are (batch, height, width,
            # classes) -- confirm against the training data generator.
            class_weights = tf.constant(
                np.array(weights, dtype=np.float32)[None, None, None])
            # Largest weight of any active class at each position; the 1.0
            # floor keeps positions with no active class from being zeroed.
            weight_mask = tf.maximum(
                tf.reduce_max(class_weights * labels_floats, axis=-1), 1.0)
            per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
        return tf.reduce_mean(per_pixel_loss)

    # Return the closure itself: the previous `self.loss` referenced an
    # attribute that is never defined and raised AttributeError.
    return loss
# def weighted_categorical_crossentropy(self,weights=None):
#
# def loss(y_true, y_pred):
# labels_floats = tf.cast(y_true, tf.float32)
# per_pixel_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_floats,logits=y_pred)
#
# if weights is not None:
# weight_mask = tf.maximum(tf.reduce_max(tf.constant(
# np.array(weights, dtype=np.float32)[None, None, None])
# * labels_floats, axis=-1), 1.0)
# per_pixel_loss = per_pixel_loss * weight_mask[:, :, :, None]
# return tf.reduce_mean(per_pixel_loss)
# return self.loss
def IoU(self,Yi,y_predi):
@ -177,12 +179,13 @@ class sbb_predict:
##if self.weights_dir!=None:
##self.model.load_weights(self.weights_dir)
assert isinstance(self.model, Model)
if self.task != 'classification' and self.task != 'reading_order':
self.img_height=self.model.layers[len(self.model.layers)-1].output_shape[1]
self.img_width=self.model.layers[len(self.model.layers)-1].output_shape[2]
self.n_classes=self.model.layers[len(self.model.layers)-1].output_shape[3]
def visualize_model_output(self, prediction, img, task):
def visualize_model_output(self, prediction, img, task) -> Tuple[NDArray, NDArray]:
if task == "binarization":
prediction = prediction * -1
prediction = prediction + 1
@ -226,9 +229,12 @@ class sbb_predict:
added_image = cv2.addWeighted(img,0.5,layout_only,0.1,0)
assert isinstance(added_image, np.ndarray)
assert isinstance(layout_only, np.ndarray)
return added_image, layout_only
def predict(self, image_dir):
assert isinstance(self.model, Model)
if self.task == 'classification':
classes_names = self.config_params_model['classification_classes_name']
img_1ch = img=cv2.imread(image_dir, 0)
@ -240,7 +246,7 @@ class sbb_predict:
img_in[0, :, :, 1] = img_1ch[:, :]
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = self.model.predict(img_in, verbose=0)
label_p_pred = self.model.predict(img_in, verbose='0')
index_class = np.argmax(label_p_pred[0])
print("Predicted Class: {}".format(classes_names[str(int(index_class))]))
@ -361,7 +367,7 @@ class sbb_predict:
#input_1[:,:,1] = img3[:,:,0]/5.
if batch_counter==inference_bs or ( (tot_counter//inference_bs)==full_bs_ite and tot_counter%inference_bs==last_bs):
y_pr = self.model.predict(input_1 , verbose=0)
y_pr = self.model.predict(input_1 , verbose='0')
scalibility_num = scalibility_num+1
if batch_counter==inference_bs:
@ -395,6 +401,7 @@ class sbb_predict:
name_space = name_space.split('{')[1]
page_element = root_xml.find(link+'Page')
assert isinstance(page_element, ET.Element)
"""
ro_subelement = ET.SubElement(page_element, 'ReadingOrder')
@ -489,7 +496,7 @@ class sbb_predict:
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = self.model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
verbose=0)
verbose='0')
if self.task == 'enhancement':
seg = label_p_pred[0, :, :, :]
@ -497,6 +504,8 @@ class sbb_predict:
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")
if i == 0 and j == 0:
@ -551,6 +560,8 @@ class sbb_predict:
elif self.task == 'segmentation' or self.task == 'binarization':
seg = np.argmax(label_p_pred, axis=3)[0]
seg = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
else:
raise ValueError(f"Unhandled task {self.task}")
prediction_true = seg.astype(int)