rfct: move all tensorflow/keras imports and hacks to utils.tf

refactoring-2024-08-merged
kba 4 months ago
parent b15b1bdcd5
commit 8c4bfa229f

@ -8,11 +8,10 @@ document layout analysis (segmentation) with output in PAGE-XML
from logging import Logger from logging import Logger
import math import math
import os from os import listdir
import sys from os.path import join
import time import time
from typing import Optional from typing import Optional
import warnings
from pathlib import Path from pathlib import Path
from multiprocessing import Process, Queue, cpu_count from multiprocessing import Process, Queue, cpu_count
from PIL.Image import Image from PIL.Image import Image
@ -23,19 +22,7 @@ from scipy.signal import find_peaks
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d from scipy.ndimage import gaussian_filter1d
from qurator.eynollah.utils.keras import PatchEncoder, Patches from qurator.eynollah.utils.tf import tf, PatchEncoder, Patches
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
import tensorflow as tf
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
load_model = tf.keras.models.load_model
# use tf1 compatibility for keras backend
set_session = tf.compat.v1.keras.backend.set_session
from .utils.contour import ( from .utils.contour import (
filter_contours_area_of_image, filter_contours_area_of_image,
@ -182,7 +169,7 @@ class Eynollah():
self.model_textline_dir = dir_models + "/eynollah-textline_20210425" self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
self.model_tables = dir_models + "/eynollah-tables_20210319" self.model_tables = dir_models + "/eynollah-tables_20210319"
self.models = {} self.models : dict[str, tf.keras.Model] = {}
if dir_in and light_version: if dir_in and light_version:
config = tf.compat.v1.ConfigProto() config = tf.compat.v1.ConfigProto()
@ -198,7 +185,7 @@ class Eynollah():
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully) self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.ls_imgs = os.listdir(self.dir_in) self.ls_imgs = listdir(self.dir_in)
if dir_in and not light_version: if dir_in and not light_version:
config = tf.compat.v1.ConfigProto() config = tf.compat.v1.ConfigProto()
@ -216,7 +203,7 @@ class Eynollah():
self.model_region_fl = self.our_load_model(self.model_region_dir_fully) self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
self.ls_imgs = os.listdir(self.dir_in) self.ls_imgs = listdir(self.dir_in)
def _cache_images(self, image_filename=None, image_pil=None): def _cache_images(self, image_filename=None, image_pil=None):
@ -586,9 +573,8 @@ class Eynollah():
self.writer.scale_x = self.scale_x self.writer.scale_x = self.scale_x
self.writer.height_org = self.height_org self.writer.height_org = self.height_org
self.writer.width_org = self.width_org self.writer.width_org = self.width_org
def start_new_session_and_model(self, model_dir): def start_new_session_and_model(self, model_dir) -> tf.keras.Model:
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir) self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
#gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
#gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True) #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
@ -613,8 +599,7 @@ class Eynollah():
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
self.models[model_dir] = model self.models[model_dir] = model
# FIXME: why? return model
return model, None
def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1): def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
self.logger.debug("enter do_prediction") self.logger.debug("enter do_prediction")
@ -910,7 +895,7 @@ class Eynollah():
img = cv2.GaussianBlur(self.image, (5, 5), 0) img = cv2.GaussianBlur(self.image, (5, 5), 0)
if not self.dir_in: if not self.dir_in:
model_page, session_page = self.start_new_session_and_model(self.model_page_dir) model_page = self.start_new_session_and_model(self.model_page_dir)
if not self.dir_in: if not self.dir_in:
img_page_prediction = self.do_prediction(False, img, model_page) img_page_prediction = self.do_prediction(False, img, model_page)
@ -958,7 +943,7 @@ class Eynollah():
else: else:
img = self.imread() img = self.imread()
if not self.dir_in: if not self.dir_in:
model_page, session_page = self.start_new_session_and_model(self.model_page_dir) model_page = self.start_new_session_and_model(self.model_page_dir)
img = cv2.GaussianBlur(img, (5, 5), 0) img = cv2.GaussianBlur(img, (5, 5), 0)
if self.dir_in: if self.dir_in:
@ -2774,7 +2759,7 @@ class Eynollah():
for img_name in self.ls_imgs: for img_name in self.ls_imgs:
t0 = time.time() t0 = time.time()
if self.dir_in: if self.dir_in:
self.reset_file_name_dir(os.path.join(self.dir_in,img_name)) self.reset_file_name_dir(join(self.dir_in,img_name))
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version) img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version)
self.logger.info("Enhancing took %.1fs ", time.time() - t0) self.logger.info("Enhancing took %.1fs ", time.time() - t0)

@ -1,10 +1,37 @@
import warnings
from os import environ, devnull
import sys
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
stderr = sys.stderr
sys.stderr = open(devnull, "w")
import tensorflow as tf import tensorflow as tf
sys.stderr = stderr
tf.get_logger().setLevel("ERROR")
warnings.filterwarnings("ignore")
layers = tf.keras.layers layers = tf.keras.layers
__all__ = [
'PatchEncoder',
'Patches',
'load_model',
'set_session',
'tf',
]
PROJECTION_DIM = 64 PROJECTION_DIM = 64
PATCH_SIZE = 1 PATCH_SIZE = 1
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28 NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
def load_model(*args, **kwargs) -> tf.keras.Model:
ret = tf.keras.models.load_model(*args, **kwargs)
assert isinstance(ret, tf.keras.Model)
return ret
# use tf1 compatibility for keras backend
set_session = tf.compat.v1.keras.backend.set_session
class Patches(layers.Layer): class Patches(layers.Layer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -55,4 +82,3 @@ class PatchEncoder(layers.Layer):
}) })
return config return config
Loading…
Cancel
Save