mirror of
https://github.com/qurator-spk/eynollah.git
synced 2025-06-09 12:19:54 +02:00
rfct: move all tensorflow/keras imports and hacks to utils.tf
This commit is contained in:
parent
b15b1bdcd5
commit
8c4bfa229f
2 changed files with 38 additions and 27 deletions
|
@ -8,11 +8,10 @@ document layout analysis (segmentation) with output in PAGE-XML
|
|||
|
||||
from logging import Logger
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from os import listdir
|
||||
from os.path import join
|
||||
import time
|
||||
from typing import Optional
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from multiprocessing import Process, Queue, cpu_count
|
||||
from PIL.Image import Image
|
||||
|
@ -23,19 +22,7 @@ from scipy.signal import find_peaks
|
|||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
from qurator.eynollah.utils.keras import PatchEncoder, Patches
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
stderr = sys.stderr
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
import tensorflow as tf
|
||||
sys.stderr = stderr
|
||||
tf.get_logger().setLevel("ERROR")
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
load_model = tf.keras.models.load_model
|
||||
# use tf1 compatibility for keras backend
|
||||
set_session = tf.compat.v1.keras.backend.set_session
|
||||
from qurator.eynollah.utils.tf import tf, PatchEncoder, Patches
|
||||
|
||||
from .utils.contour import (
|
||||
filter_contours_area_of_image,
|
||||
|
@ -182,7 +169,7 @@ class Eynollah():
|
|||
self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
|
||||
self.model_tables = dir_models + "/eynollah-tables_20210319"
|
||||
|
||||
self.models = {}
|
||||
self.models : dict[str, tf.keras.Model] = {}
|
||||
|
||||
if dir_in and light_version:
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
|
@ -198,7 +185,7 @@ class Eynollah():
|
|||
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
|
||||
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
|
||||
|
||||
self.ls_imgs = os.listdir(self.dir_in)
|
||||
self.ls_imgs = listdir(self.dir_in)
|
||||
|
||||
if dir_in and not light_version:
|
||||
config = tf.compat.v1.ConfigProto()
|
||||
|
@ -216,7 +203,7 @@ class Eynollah():
|
|||
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
|
||||
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
|
||||
|
||||
self.ls_imgs = os.listdir(self.dir_in)
|
||||
self.ls_imgs = listdir(self.dir_in)
|
||||
|
||||
|
||||
def _cache_images(self, image_filename=None, image_pil=None):
|
||||
|
@ -586,9 +573,8 @@ class Eynollah():
|
|||
self.writer.scale_x = self.scale_x
|
||||
self.writer.height_org = self.height_org
|
||||
self.writer.width_org = self.width_org
|
||||
|
||||
|
||||
def start_new_session_and_model(self, model_dir):
|
||||
def start_new_session_and_model(self, model_dir) -> tf.keras.Model:
|
||||
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
|
||||
#gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
|
||||
#gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
|
||||
|
@ -613,8 +599,7 @@ class Eynollah():
|
|||
model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
|
||||
self.models[model_dir] = model
|
||||
|
||||
# FIXME: why?
|
||||
return model, None
|
||||
return model
|
||||
|
||||
def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
|
||||
self.logger.debug("enter do_prediction")
|
||||
|
@ -910,7 +895,7 @@ class Eynollah():
|
|||
img = cv2.GaussianBlur(self.image, (5, 5), 0)
|
||||
|
||||
if not self.dir_in:
|
||||
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
|
||||
model_page = self.start_new_session_and_model(self.model_page_dir)
|
||||
|
||||
if not self.dir_in:
|
||||
img_page_prediction = self.do_prediction(False, img, model_page)
|
||||
|
@ -958,7 +943,7 @@ class Eynollah():
|
|||
else:
|
||||
img = self.imread()
|
||||
if not self.dir_in:
|
||||
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
|
||||
model_page = self.start_new_session_and_model(self.model_page_dir)
|
||||
img = cv2.GaussianBlur(img, (5, 5), 0)
|
||||
|
||||
if self.dir_in:
|
||||
|
@ -2774,7 +2759,7 @@ class Eynollah():
|
|||
for img_name in self.ls_imgs:
|
||||
t0 = time.time()
|
||||
if self.dir_in:
|
||||
self.reset_file_name_dir(os.path.join(self.dir_in,img_name))
|
||||
self.reset_file_name_dir(join(self.dir_in,img_name))
|
||||
|
||||
img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version)
|
||||
self.logger.info("Enhancing took %.1fs ", time.time() - t0)
|
||||
|
|
|
@ -1,10 +1,37 @@
|
|||
import warnings
|
||||
from os import environ, devnull
|
||||
import sys
|
||||
|
||||
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
stderr = sys.stderr
|
||||
sys.stderr = open(devnull, "w")
|
||||
import tensorflow as tf
|
||||
sys.stderr = stderr
|
||||
tf.get_logger().setLevel("ERROR")
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
layers = tf.keras.layers
|
||||
|
||||
__all__ = [
|
||||
'PatchEncoder',
|
||||
'Patches',
|
||||
'load_model',
|
||||
'set_session',
|
||||
'tf',
|
||||
]
|
||||
|
||||
PROJECTION_DIM = 64
|
||||
PATCH_SIZE = 1
|
||||
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
|
||||
|
||||
def load_model(*args, **kwargs) -> tf.keras.Model:
|
||||
ret = tf.keras.models.load_model(*args, **kwargs)
|
||||
assert isinstance(ret, tf.keras.Model)
|
||||
return ret
|
||||
|
||||
# use tf1 compatibility for keras backend
|
||||
set_session = tf.compat.v1.keras.backend.set_session
|
||||
|
||||
class Patches(layers.Layer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -55,4 +82,3 @@ class PatchEncoder(layers.Layer):
|
|||
})
|
||||
return config
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue