rfct: move all tensorflow/keras imports and hacks to utils.tf

refactoring-2024-08-merged
kba 4 months ago
parent b15b1bdcd5
commit 8c4bfa229f

@ -8,11 +8,10 @@ document layout analysis (segmentation) with output in PAGE-XML
from logging import Logger
import math
import os
import sys
from os import listdir
from os.path import join
import time
from typing import Optional
import warnings
from pathlib import Path
from multiprocessing import Process, Queue, cpu_count
from PIL.Image import Image
@ -23,19 +22,7 @@ from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from qurator.eynollah.utils.keras import PatchEncoder, Patches
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
import tensorflow as tf
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
load_model = tf.keras.models.load_model
# use tf1 compatibility for keras backend
set_session = tf.compat.v1.keras.backend.set_session
from qurator.eynollah.utils.tf import tf, PatchEncoder, Patches
from .utils.contour import (
filter_contours_area_of_image,
@ -182,7 +169,7 @@ class Eynollah():
self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
self.model_tables = dir_models + "/eynollah-tables_20210319"
self.models = {}
self.models : dict[str, tf.keras.Model] = {}
if dir_in and light_version:
config = tf.compat.v1.ConfigProto()
@ -198,7 +185,7 @@ class Eynollah():
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.ls_imgs = os.listdir(self.dir_in)
self.ls_imgs = listdir(self.dir_in)
if dir_in and not light_version:
config = tf.compat.v1.ConfigProto()
@ -216,7 +203,7 @@ class Eynollah():
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.ls_imgs = os.listdir(self.dir_in)
self.ls_imgs = listdir(self.dir_in)
def _cache_images(self, image_filename=None, image_pil=None):
@ -586,9 +573,8 @@ class Eynollah():
self.writer.scale_x = self.scale_x
self.writer.height_org = self.height_org
self.writer.width_org = self.width_org
def start_new_session_and_model(self, model_dir):
def start_new_session_and_model(self, model_dir) -> tf.keras.Model:
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
#gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
#gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
@ -613,8 +599,7 @@ class Eynollah():
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
self.models[model_dir] = model
# FIXME: why?
return model, None
return model
def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
self.logger.debug("enter do_prediction")
@ -910,7 +895,7 @@ class Eynollah():
img = cv2.GaussianBlur(self.image, (5, 5), 0)
if not self.dir_in:
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
model_page = self.start_new_session_and_model(self.model_page_dir)
if not self.dir_in:
img_page_prediction = self.do_prediction(False, img, model_page)
@ -958,7 +943,7 @@ class Eynollah():
else:
img = self.imread()
if not self.dir_in:
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
model_page = self.start_new_session_and_model(self.model_page_dir)
img = cv2.GaussianBlur(img, (5, 5), 0)
if self.dir_in:
@ -2774,7 +2759,7 @@ class Eynollah():
for img_name in self.ls_imgs:
t0 = time.time()
if self.dir_in:
self.reset_file_name_dir(os.path.join(self.dir_in,img_name))
self.reset_file_name_dir(join(self.dir_in,img_name))
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version)
self.logger.info("Enhancing took %.1fs ", time.time() - t0)

@ -1,10 +1,37 @@
import warnings
from os import environ, devnull
import sys
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(devnull, "w")
import tensorflow as tf
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
layers = tf.keras.layers
__all__ = [
'PatchEncoder',
'Patches',
'load_model',
'set_session',
'tf',
]
PROJECTION_DIM = 64
PATCH_SIZE = 1
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
def load_model(*args, **kwargs) -> tf.keras.Model:
ret = tf.keras.models.load_model(*args, **kwargs)
assert isinstance(ret, tf.keras.Model)
return ret
# use tf1 compatibility for keras backend
set_session = tf.compat.v1.keras.backend.set_session
class Patches(layers.Layer):
def __init__(self, *args, **kwargs):
@ -55,4 +82,3 @@ class PatchEncoder(layers.Layer):
})
return config
Loading…
Cancel
Save