mirror of
				https://github.com/qurator-spk/eynollah.git
				synced 2025-11-04 11:44:15 +01:00 
			
		
		
		
	rfct: move all tensorflow/keras imports and hacks to utils.tf
This commit is contained in:
		
							parent
							
								
									b15b1bdcd5
								
							
						
					
					
						commit
						8c4bfa229f
					
				
					 2 changed files with 38 additions and 27 deletions
				
			
		| 
						 | 
				
			
			@ -8,11 +8,10 @@ document layout analysis (segmentation) with output in PAGE-XML
 | 
			
		|||
 | 
			
		||||
from logging import Logger
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
from os import listdir
 | 
			
		||||
from os.path import join
 | 
			
		||||
import time
 | 
			
		||||
from typing import Optional
 | 
			
		||||
import warnings
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from multiprocessing import Process, Queue, cpu_count
 | 
			
		||||
from PIL.Image import Image
 | 
			
		||||
| 
						 | 
				
			
			@ -23,19 +22,7 @@ from scipy.signal import find_peaks
 | 
			
		|||
import matplotlib.pyplot as plt
 | 
			
		||||
from scipy.ndimage import gaussian_filter1d
 | 
			
		||||
 | 
			
		||||
from qurator.eynollah.utils.keras import PatchEncoder, Patches
 | 
			
		||||
 | 
			
		||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 | 
			
		||||
stderr = sys.stderr
 | 
			
		||||
sys.stderr = open(os.devnull, "w")
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
sys.stderr = stderr
 | 
			
		||||
tf.get_logger().setLevel("ERROR")
 | 
			
		||||
warnings.filterwarnings("ignore")
 | 
			
		||||
 | 
			
		||||
load_model = tf.keras.models.load_model
 | 
			
		||||
# use tf1 compatibility for keras backend
 | 
			
		||||
set_session = tf.compat.v1.keras.backend.set_session
 | 
			
		||||
from qurator.eynollah.utils.tf import tf, PatchEncoder, Patches
 | 
			
		||||
 | 
			
		||||
from .utils.contour import (
 | 
			
		||||
    filter_contours_area_of_image,
 | 
			
		||||
| 
						 | 
				
			
			@ -182,7 +169,7 @@ class Eynollah():
 | 
			
		|||
            self.model_textline_dir = dir_models + "/eynollah-textline_20210425"
 | 
			
		||||
        self.model_tables = dir_models + "/eynollah-tables_20210319"
 | 
			
		||||
        
 | 
			
		||||
        self.models = {}
 | 
			
		||||
        self.models : dict[str, tf.keras.Model] = {}
 | 
			
		||||
        
 | 
			
		||||
        if dir_in and light_version:
 | 
			
		||||
            config = tf.compat.v1.ConfigProto()
 | 
			
		||||
| 
						 | 
				
			
			@ -198,7 +185,7 @@ class Eynollah():
 | 
			
		|||
            self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
 | 
			
		||||
            self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
 | 
			
		||||
            
 | 
			
		||||
            self.ls_imgs  = os.listdir(self.dir_in)
 | 
			
		||||
            self.ls_imgs  = listdir(self.dir_in)
 | 
			
		||||
            
 | 
			
		||||
        if dir_in and not light_version:
 | 
			
		||||
            config = tf.compat.v1.ConfigProto()
 | 
			
		||||
| 
						 | 
				
			
			@ -216,7 +203,7 @@ class Eynollah():
 | 
			
		|||
            self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
 | 
			
		||||
            self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
 | 
			
		||||
            
 | 
			
		||||
            self.ls_imgs  = os.listdir(self.dir_in)
 | 
			
		||||
            self.ls_imgs  = listdir(self.dir_in)
 | 
			
		||||
            
 | 
			
		||||
        
 | 
			
		||||
    def _cache_images(self, image_filename=None, image_pil=None):
 | 
			
		||||
| 
						 | 
				
			
			@ -586,9 +573,8 @@ class Eynollah():
 | 
			
		|||
        self.writer.scale_x = self.scale_x
 | 
			
		||||
        self.writer.height_org = self.height_org
 | 
			
		||||
        self.writer.width_org = self.width_org
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    def start_new_session_and_model(self, model_dir):
 | 
			
		||||
    def start_new_session_and_model(self, model_dir) -> tf.keras.Model:
 | 
			
		||||
        self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
 | 
			
		||||
        #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
 | 
			
		||||
        #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
 | 
			
		||||
| 
						 | 
				
			
			@ -613,8 +599,7 @@ class Eynollah():
 | 
			
		|||
                model = load_model(model_dir , compile=False, custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches})
 | 
			
		||||
                self.models[model_dir] = model
 | 
			
		||||
 | 
			
		||||
        # FIXME: why?
 | 
			
		||||
        return model, None
 | 
			
		||||
        return model
 | 
			
		||||
 | 
			
		||||
    def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
 | 
			
		||||
        self.logger.debug("enter do_prediction")
 | 
			
		||||
| 
						 | 
				
			
			@ -910,7 +895,7 @@ class Eynollah():
 | 
			
		|||
            img = cv2.GaussianBlur(self.image, (5, 5), 0)
 | 
			
		||||
            
 | 
			
		||||
            if not self.dir_in:
 | 
			
		||||
                model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
 | 
			
		||||
                model_page = self.start_new_session_and_model(self.model_page_dir)
 | 
			
		||||
                
 | 
			
		||||
            if not self.dir_in:
 | 
			
		||||
                img_page_prediction = self.do_prediction(False, img, model_page)
 | 
			
		||||
| 
						 | 
				
			
			@ -958,7 +943,7 @@ class Eynollah():
 | 
			
		|||
            else:
 | 
			
		||||
                img = self.imread()
 | 
			
		||||
            if not self.dir_in:
 | 
			
		||||
                model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
 | 
			
		||||
                model_page = self.start_new_session_and_model(self.model_page_dir)
 | 
			
		||||
            img = cv2.GaussianBlur(img, (5, 5), 0)
 | 
			
		||||
            
 | 
			
		||||
            if self.dir_in:
 | 
			
		||||
| 
						 | 
				
			
			@ -2774,7 +2759,7 @@ class Eynollah():
 | 
			
		|||
        for img_name in self.ls_imgs:
 | 
			
		||||
            t0 = time.time()
 | 
			
		||||
            if self.dir_in:
 | 
			
		||||
                self.reset_file_name_dir(os.path.join(self.dir_in,img_name))
 | 
			
		||||
                self.reset_file_name_dir(join(self.dir_in,img_name))
 | 
			
		||||
            
 | 
			
		||||
            img_res, is_image_enhanced, num_col_classifier, num_column_is_classified = self.run_enhancement(self.light_version)
 | 
			
		||||
            self.logger.info("Enhancing took %.1fs ", time.time() - t0)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,10 +1,37 @@
 | 
			
		|||
import warnings
 | 
			
		||||
from os import environ, devnull
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 | 
			
		||||
stderr = sys.stderr
 | 
			
		||||
sys.stderr = open(devnull, "w")
 | 
			
		||||
import tensorflow as tf
 | 
			
		||||
sys.stderr = stderr
 | 
			
		||||
tf.get_logger().setLevel("ERROR")
 | 
			
		||||
warnings.filterwarnings("ignore")
 | 
			
		||||
 | 
			
		||||
layers = tf.keras.layers
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    'PatchEncoder',
 | 
			
		||||
    'Patches',
 | 
			
		||||
    'load_model',
 | 
			
		||||
    'set_session',
 | 
			
		||||
    'tf',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
PROJECTION_DIM = 64
 | 
			
		||||
PATCH_SIZE = 1
 | 
			
		||||
NUM_PATCHES = 21*21 #14*14#28*28#14*14#28*28
 | 
			
		||||
 | 
			
		||||
def load_model(*args, **kwargs) -> tf.keras.Model:
 | 
			
		||||
    ret = tf.keras.models.load_model(*args, **kwargs)
 | 
			
		||||
    assert isinstance(ret, tf.keras.Model)
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
# use tf1 compatibility for keras backend
 | 
			
		||||
set_session = tf.compat.v1.keras.backend.set_session
 | 
			
		||||
 | 
			
		||||
class Patches(layers.Layer):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			@ -55,4 +82,3 @@ class PatchEncoder(layers.Layer):
 | 
			
		|||
        })
 | 
			
		||||
        return config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue