remove session methods and redundant model loaders

pull/148/head^2
Robert Sachunsky 2 weeks ago
parent dd51f900b9
commit 1a0a1cb00b

@ -323,50 +323,52 @@ class Eynollah:
self.model_textline_dir = dir_models + "/modelens_textline_0_1__2_4_16092024"
if self.ocr:
self.model_ocr_dir = dir_models + "/trocr_model_ens_of_3_checkpoints_201124"
if self.tables:
if self.light_version:
self.model_table_dir = dir_models + "/modelens_table_0t4_201124"
else:
self.model_table_dir = dir_models + "/eynollah-tables_20210319"
self.models = {}
if dir_in:
# as in start_new_session:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
set_session(session)
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
if self.extract_only_images:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
# #gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
# #gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
# #session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
# config = tf.compat.v1.ConfigProto()
# config.gpu_options.allow_growth = True
# #session = tf.InteractiveSession()
# session = tf.compat.v1.Session(config=config)
# set_session(session)
try:
for device in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
self.model_page = self.our_load_model(self.model_page_dir)
self.model_classifier = self.our_load_model(self.model_dir_of_col_classifier)
self.model_bin = self.our_load_model(self.model_dir_of_binarization)
if self.extract_only_images:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light_only_images_extraction)
else:
self.model_textline = self.our_load_model(self.model_textline_dir)
if self.light_version:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
else:
self.model_textline = self.our_load_model(self.model_textline_dir)
if self.light_version:
self.model_region = self.our_load_model(self.model_region_dir_p_ens_light)
self.model_region_1_2 = self.our_load_model(self.model_region_dir_p_1_2_sp_np)
else:
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
if self.reading_order_machine_based:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
if self.tables:
self.model_table = self.our_load_model(self.model_table_dir)
self.ls_imgs = os.listdir(self.dir_in)
self.model_region = self.our_load_model(self.model_region_dir_p_ens)
self.model_region_p2 = self.our_load_model(self.model_region_dir_p2)
self.model_enhancement = self.our_load_model(self.model_dir_of_enhancement)
###self.model_region_fl_new = self.our_load_model(self.model_region_dir_fully_new)
self.model_region_fl_np = self.our_load_model(self.model_region_dir_fully_np)
self.model_region_fl = self.our_load_model(self.model_region_dir_fully)
if self.reading_order_machine_based:
self.model_reading_order = self.our_load_model(self.model_reading_order_dir)
if self.ocr:
self.model_ocr = VisionEncoderDecoderModel.from_pretrained(self.model_ocr_dir)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#("microsoft/trocr-base-printed")#("microsoft/trocr-base-handwritten")
self.processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
if self.tables:
self.model_table = self.our_load_model(self.model_table_dir)
def _cache_images(self, image_filename=None, image_pil=None):
ret = {}
@ -421,8 +423,6 @@ class Eynollah:
def predict_enhancement(self, img):
self.logger.debug("enter predict_enhancement")
if not self.dir_in:
self.model_enhancement, _ = self.start_new_session_and_model(self.model_dir_of_enhancement)
img_height_model = self.model_enhancement.layers[-1].output_shape[1]
img_width_model = self.model_enhancement.layers[-1].output_shape[2]
@ -619,9 +619,6 @@ class Eynollah:
img = self.imread()
_, page_coord = self.early_page_for_num_of_column_classification(img)
if not self.dir_in:
self.model_classifier, _ = self.start_new_session_and_model(self.model_dir_of_col_classifier)
if self.input_binary:
img_in = np.copy(img)
@ -662,9 +659,6 @@ class Eynollah:
self.logger.info("Detected %s DPI", dpi)
if self.input_binary:
img = self.imread()
if not self.dir_in:
self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
@ -673,17 +667,14 @@ class Eynollah:
else:
img = self.imread()
img_bin = None
width_early = img.shape[1]
t1 = time.time()
_, page_coord = self.early_page_for_num_of_column_classification(img_bin)
self.image_page_org_size = img[page_coord[0] : page_coord[1], page_coord[2] : page_coord[3], :]
self.page_coord = page_coord
if not self.dir_in:
self.model_classifier, _ = self.start_new_session_and_model(self.model_dir_of_col_classifier)
if self.num_col_upper and not self.num_col_lower:
num_col = self.num_col_upper
label_p_pred = [np.ones(6)]
@ -823,43 +814,6 @@ class Eynollah:
self.writer.height_org = self.height_org
self.writer.width_org = self.width_org
def start_new_session_and_model_old(self, model_dir):
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.InteractiveSession()
model = load_model(model_dir, compile=False)
return model, session
def start_new_session_and_model(self, model_dir):
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
#gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
#gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
#session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
physical_devices = tf.config.list_physical_devices('GPU')
try:
for device in physical_devices:
tf.config.experimental.set_memory_growth(device, True)
except:
self.logger.warning("no GPU device available")
if model_dir.endswith('.h5') and Path(model_dir[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_dir = model_dir[:-3]
if model_dir in self.models:
model = self.models[model_dir]
else:
try:
model = load_model(model_dir, compile=False)
except:
model = load_model(model_dir , compile=False, custom_objects={
"PatchEncoder": PatchEncoder, "Patches": Patches})
self.models[model_dir] = model
return model, None
def do_prediction(
self, patches, img, model,
n_batch_inference=1, marginal_of_patch_percent=0.1,
@ -1397,9 +1351,6 @@ class Eynollah:
self.logger.debug("enter extract_page")
cont_page = []
if not self.ignore_page_extraction:
if not self.dir_in:
self.model_page, _ = self.start_new_session_and_model(self.model_page_dir)
img = cv2.GaussianBlur(self.image, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
@ -1447,8 +1398,6 @@ class Eynollah:
img = np.copy(img_bin).astype(np.uint8)
else:
img = self.imread()
if not self.dir_in:
self.model_page, _ = self.start_new_session_and_model(self.model_page_dir)
img = cv2.GaussianBlur(img, (5, 5), 0)
img_page_prediction = self.do_prediction(False, img, self.model_page)
@ -1476,11 +1425,6 @@ class Eynollah:
self.logger.debug("enter extract_text_regions")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
if not self.dir_in:
if patches:
self.model_region_fl, _ = self.start_new_session_and_model(self.model_region_dir_fully)
else:
self.model_region_fl_np, _ = self.start_new_session_and_model(self.model_region_dir_fully_np)
model_region = self.model_region_fl if patches else self.model_region_fl_np
if self.light_version:
@ -1512,11 +1456,6 @@ class Eynollah:
self.logger.debug("enter extract_text_regions")
img_height_h = img.shape[0]
img_width_h = img.shape[1]
if not self.dir_in:
if patches:
self.model_region_fl, _ = self.start_new_session_and_model(self.model_region_dir_fully)
else:
self.model_region_fl_np, _ = self.start_new_session_and_model(self.model_region_dir_fully_np)
model_region = self.model_region_fl if patches else self.model_region_fl_np
if not patches:
@ -1647,8 +1586,6 @@ class Eynollah:
def textline_contours(self, img, use_patches, scaler_h, scaler_w, num_col_classifier=None):
self.logger.debug('enter textline_contours')
if not self.dir_in:
self.model_textline, _ = self.start_new_session_and_model(self.model_textline_dir)
#img = img.astype(np.uint8)
img_org = np.copy(img)
@ -1750,9 +1687,6 @@ class Eynollah:
img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)
img_resized = resize_image(img,img_h_new, img_w_new )
if not self.dir_in:
self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens_light_only_images_extraction)
prediction_regions_org = self.do_prediction_new_concept(True, img_resized, self.model_region)
prediction_regions_org = resize_image(prediction_regions_org,img_height_h, img_width_h )
@ -1841,7 +1775,6 @@ class Eynollah:
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]
#model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
#print(num_col_classifier,'num_col_classifier')
if num_col_classifier == 1:
@ -1864,8 +1797,6 @@ class Eynollah:
#if self.input_binary:
#img_bin = np.copy(img_resized)
###if (not self.input_binary and self.full_layout) or (not self.input_binary and num_col_classifier >= 30):
###if not self.dir_in:
###self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
###prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
####print("inside bin ", time.time()-t_bin)
@ -1881,8 +1812,6 @@ class Eynollah:
###else:
###img_bin = np.copy(img_resized)
if self.ocr and not self.input_binary:
if not self.dir_in:
self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
prediction_bin = 255 * (prediction_bin[:,:,0] == 0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
@ -1905,12 +1834,7 @@ class Eynollah:
#plt.show()
if not skip_layout_and_reading_order:
#print("inside 2 ", time.time()-t_in)
if not self.dir_in:
self.model_region_1_2, _ = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np)
##self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens_light)
if num_col_classifier == 1 or num_col_classifier == 2:
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_1_2_sp_np)
if self.image_org.shape[0]/self.image_org.shape[1] > 2.5:
self.logger.debug("resized to %dx%d for %d cols",
img_resized.shape[1], img_resized.shape[0], num_col_classifier)
@ -2008,9 +1932,6 @@ class Eynollah:
img_org = np.copy(img)
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]
if not self.dir_in:
self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
ratio_y=1.3
ratio_x=1
@ -2037,11 +1958,8 @@ class Eynollah:
prediction_regions_org=prediction_regions_org[:,:,0]
prediction_regions_org[(prediction_regions_org[:,:]==1) & (mask_zeros_y[:,:]==1)]=0
if not self.dir_in:
self.model_region_p2, _ = self.start_new_session_and_model(self.model_region_dir_p2)
img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
@ -2066,15 +1984,11 @@ class Eynollah:
if self.input_binary:
prediction_bin = np.copy(img_org)
else:
if not self.dir_in:
self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
if not self.dir_in:
self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
ratio_y=1
ratio_x=1
@ -2107,17 +2021,10 @@ class Eynollah:
except:
if self.input_binary:
prediction_bin = np.copy(img_org)
if not self.dir_in:
self.model_bin, _ = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
prediction_bin = 255 * (prediction_bin[:,:,0]==0)
prediction_bin = np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
if not self.dir_in:
self.model_region, _ = self.start_new_session_and_model(self.model_region_dir_p_ens)
else:
prediction_bin = np.copy(img_org)
ratio_y=1
@ -2747,10 +2654,6 @@ class Eynollah:
img_org = np.copy(img)
img_height_h = img_org.shape[0]
img_width_h = img_org.shape[1]
if not self.dir_in:
self.model_table, _ = self.start_new_session_and_model(self.model_table_dir)
patches = False
if self.light_version:
prediction_table = self.do_prediction_new_concept(patches, img, self.model_table)
@ -3386,8 +3289,12 @@ class Eynollah:
return (polygons_of_images, img_revised_tab, text_regions_p_1_n, textline_mask_tot_d,
regions_without_separators_d, regions_fully, regions_without_separators,
polygons_of_marginals, contours_tables)
def our_load_model(self, model_file):
@staticmethod
def our_load_model(model_file):
if model_file.endswith('.h5') and Path(model_file[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_file = model_file[:-3]
try:
model = load_model(model_file, compile=False)
except:
@ -3438,9 +3345,6 @@ class Eynollah:
img_header_and_sep = resize_image(img_header_and_sep, height1, width1)
img_poly = resize_image(img_poly, height3, width3)
if not self.dir_in:
self.model_reading_order, _ = self.start_new_session_and_model(self.model_reading_order_dir)
inference_bs = 3
input_1 = np.zeros((inference_bs, height1, width1, 3))
ordered = [list(range(len(co_text_all)))]

Loading…
Cancel
Save