diff --git a/qurator/eynollah/eynollah.py b/qurator/eynollah/eynollah.py index da3385f..fce73b5 100644 --- a/qurator/eynollah/eynollah.py +++ b/qurator/eynollah/eynollah.py @@ -8,11 +8,10 @@ document layout analysis (segmentation) with output in PAGE-XML from logging import Logger import math -import os -import sys +from os import listdir +from os.path import join import time from typing import Optional -import warnings from pathlib import Path from multiprocessing import Process, Queue, cpu_count from PIL.Image import Image @@ -23,19 +22,7 @@ from scipy.signal import find_peaks import matplotlib.pyplot as plt from scipy.ndimage import gaussian_filter1d -from qurator.eynollah.utils.keras import PatchEncoder, Patches - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -stderr = sys.stderr -sys.stderr = open(os.devnull, "w") -import tensorflow as tf -sys.stderr = stderr -tf.get_logger().setLevel("ERROR") -warnings.filterwarnings("ignore") - -load_model = tf.keras.models.load_model -# use tf1 compatibility for keras backend -set_session = tf.compat.v1.keras.backend.set_session +from qurator.eynollah.utils.tf import tf, PatchEncoder, Patches from .utils.contour import ( filter_contours_area_of_image, @@ -182,7 +169,7 @@ class Eynollah(): self.model_textline_dir = dir_models + "/eynollah-textline_20210425" self.model_tables = dir_models + "/eynollah-tables_20210319" - self.models = {} + self.models : dict[str, tf.keras.Model] = {} if dir_in and light_version: config = tf.compat.v1.ConfigProto() @@ -198,7 +185,7 @@ class Eynollah(): self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np) self.model_region_fl = self.our_load_model(self.model_region_dir_fully) - self.ls_imgs = os.listdir(self.dir_in) + self.ls_imgs = listdir(self.dir_in) if dir_in and not light_version: config = tf.compat.v1.ConfigProto() @@ -216,7 +203,7 @@ class Eynollah(): self.model_region_fl = self.our_load_model(self.model_region_dir_fully) self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement) - self.ls_imgs = os.listdir(self.dir_in) + self.ls_imgs = listdir(self.dir_in) def _cache_images(self, image_filename=None, image_pil=None): @@ -586,9 +573,8 @@ class Eynollah(): self.writer.scale_x = self.scale_x self.writer.height_org = self.height_org self.writer.width_org = self.width_org - - def start_new_session_and_model(self, model_dir): + def start_new_session_and_model(self, model_dir) -> tf.keras.Model: self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir) #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True) #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True) @@ -613,8 +599,7 @@ class Eynollah(): model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) self.models[model_dir] = model - # FIXME: why? - return model, None + return model def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1): self.logger.debug("enter do_prediction") @@ -910,7 +895,7 @@ class Eynollah(): img = cv2.GaussianBlur(self.image, (5, 5), 0) if not self.dir_in: - model_page, session_page = self.start_new_session_and_model(self.model_page_dir) + model_page = self.start_new_session_and_model(self.model_page_dir) if not self.dir_in: img_page_prediction = self.do_prediction(False, img, model_page) @@ -958,7 +943,7 @@ class Eynollah(): else: img = self.imread() if not self.dir_in: - model_page, session_page = self.start_new_session_and_model(self.model_page_dir) + model_page = self.start_new_session_and_model(self.model_page_dir) img = cv2.GaussianBlur(img, (5, 5), 0) if self.dir_in: @@ -2774,7 +2759,7 @@ class Eynollah(): for img_name in self.ls_imgs: t0 = time.time() if self.dir_in: - self.reset_file_name_dir(os.path.join(self.dir_in,img_name)) + self.reset_file_name_dir(join(self.dir_in,img_name)) img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version) self.logger.info("Enhancing took %.1fs ", time.time() - t0) diff --git a/qurator/eynollah/utils/keras.py b/qurator/eynollah/utils/tf.py similarity index 75% rename from qurator/eynollah/utils/keras.py rename to qurator/eynollah/utils/tf.py index f5da4d2..546f05d 100644 --- a/qurator/eynollah/utils/keras.py +++ b/qurator/eynollah/utils/tf.py @@ -1,10 +1,37 @@ +import warnings +from os import environ, devnull +import sys + +environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +stderr = sys.stderr +sys.stderr = open(devnull, "w") import tensorflow as tf +sys.stderr = stderr +tf.get_logger().setLevel("ERROR") +warnings.filterwarnings("ignore") + layers = tf.keras.layers +__all__ = [ + 'PatchEncoder', + 'Patches', + 'load_model', + 'set_session', + 'tf', +] + PROJECTION_DIM = 64 PATCH_SIZE = 1 NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28 +def load_model(*args, **kwargs) -> tf.keras.Model: + ret = tf.keras.models.load_model(*args, **kwargs) + assert isinstance(ret, tf.keras.Model) + return ret + +# use tf1 compatibility for keras backend +set_session = tf.compat.v1.keras.backend.set_session + class Patches(layers.Layer): def __init__(self, *args, **kwargs): @@ -55,4 +82,3 @@ class PatchEncoder(layers.Layer): }) return config -