remove TF1 session and GC controls, avoid repeating load_model

pull/91/head
Robert Sachunsky 1 year ago
parent a56988a35a
commit 7345f6bf67

@ -145,6 +145,8 @@ class Eynollah:
self.model_region_dir_p_ens = dir_models + "/model_ensemble_s.h5"
self.model_textline_dir = dir_models + "/model_textline_newspapers.h5"
self.model_tables = dir_models + "/model_tables_ens_mixed_new_2.h5"
self.models = {}
def _cache_images(self, image_filename=None, image_pil=None):
ret = {}
@ -255,11 +257,6 @@ class Eynollah:
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg
prediction_true = prediction_true.astype(int)
session_enhancement.close()
del model_enhancement
del session_enhancement
gc.collect()
return prediction_true
def calculate_width_height_by_columns(self, img, num_col, width_early, label_p_pred):
@ -361,16 +358,6 @@ class Eynollah:
self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
session_col_classifier.close()
del model_num_classifier
del session_col_classifier
K.clear_session()
gc.collect()
img_new, _ = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred)
if img_new.shape[1] > img.shape[1]:
@ -394,11 +381,6 @@ class Eynollah:
prediction_bin =np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
session_bin.close()
del model_bin
del session_bin
gc.collect()
prediction_bin = prediction_bin.astype(np.uint8)
img= np.copy(prediction_bin)
img_bin = np.copy(prediction_bin)
@ -428,12 +410,9 @@ class Eynollah:
img_in[0, :, :, 2] = img_1ch[:, :]
label_p_pred = model_num_classifier.predict(img_in, verbose=0)
num_col = np.argmax(label_p_pred[0]) + 1
self.logger.info("Found %s columns (%s)", num_col, label_p_pred)
session_col_classifier.close()
K.clear_session()
self.logger.info("Found %d columns (%s)", num_col, np.around(label_p_pred, decimals=5))
if dpi < DPI_THRESHOLD:
img_new, num_column_is_classified = self.calculate_width_height_by_columns(img, num_col, width_early, label_p_pred)
@ -444,9 +423,6 @@ class Eynollah:
image_res = np.copy(img)
is_image_enhanced = False
session_col_classifier.close()
self.logger.debug("exit resize_and_enhance_image_with_column_classifier")
return is_image_enhanced, img, image_res, num_col, num_column_is_classified, img_bin
@ -513,15 +489,24 @@ class Eynollah:
def start_new_session_and_model(self, model_dir):
self.logger.debug("enter start_new_session_and_model (model_dir=%s)", model_dir)
gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
#gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
#gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=7.7, allow_growth=True)
session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
#session = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
physical_devices = tf.config.list_physical_devices('GPU')
try:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
except:
self.logger.warning("no GPU device available")
if model_dir.endswith('.h5') and Path(model_dir[:-3]).exists():
# prefer SavedModel over HDF5 format if it exists
model_dir = model_dir[:-3]
model = load_model(model_dir, compile=False)
if model_dir in self.models:
model = self.models[model_dir]
else:
model = load_model(model_dir, compile=False)
self.models[model_dir] = model
return model, session
return model, None
def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
self.logger.debug("enter do_prediction")
@ -640,8 +625,6 @@ class Eynollah:
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
prediction_true = prediction_true.astype(np.uint8)
del model
gc.collect()
return prediction_true
def early_page_for_num_of_column_classification(self,img_bin):
@ -668,19 +651,15 @@ class Eynollah:
else:
box = [0, 0, img.shape[1], img.shape[0]]
croped_page, page_coord = crop_image_inside_box(box, img)
session_page.close()
del model_page
del session_page
gc.collect()
K.clear_session()
self.logger.debug("exit early_page_for_num_of_column_classification")
return croped_page, page_coord
def extract_page(self):
self.logger.debug("enter extract_page")
cont_page = []
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
img = cv2.GaussianBlur(self.image, (5, 5), 0)
model_page, session_page = self.start_new_session_and_model(self.model_page_dir)
img_page_prediction = self.do_prediction(False, img, model_page)
imgray = cv2.cvtColor(img_page_prediction, cv2.COLOR_BGR2GRAY)
_, thresh = cv2.threshold(imgray, 0, 255, 0)
@ -707,11 +686,6 @@ class Eynollah:
box = [0, 0, img.shape[1], img.shape[0]]
croped_page, page_coord = crop_image_inside_box(box, self.image)
cont_page.append(np.array([[page_coord[2], page_coord[0]], [page_coord[3], page_coord[0]], [page_coord[3], page_coord[1]], [page_coord[2], page_coord[1]]]))
session_page.close()
del model_page
del session_page
gc.collect()
K.clear_session()
self.logger.debug("exit extract_page")
return croped_page, page_coord, cont_page
@ -807,11 +781,6 @@ class Eynollah:
prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent)
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
session_region.close()
del model_region
del session_region
gc.collect()
self.logger.debug("exit extract_text_regions")
return prediction_regions, prediction_regions2
@ -1112,9 +1081,6 @@ class Eynollah:
prediction_textline_longshot = self.do_prediction(False, img, model_textline)
prediction_textline_longshot_true_size = resize_image(prediction_textline_longshot, img_h, img_w)
session_textline.close()
return prediction_textline[:, :, 0], prediction_textline_longshot_true_size[:, :, 0]
def do_work_of_slopes(self, q, poly, box_sub, boxes_per_process, textline_mask_tot, contours_per_process):
@ -1191,11 +1157,6 @@ class Eynollah:
##plt.show()
prediction_regions_org=prediction_regions_org[:,:,0]
prediction_regions_org[(prediction_regions_org[:,:]==1) & (mask_zeros_y[:,:]==1)]=0
session_region.close()
del model_region
del session_region
gc.collect()
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p2)
img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))
@ -1203,11 +1164,6 @@ class Eynollah:
prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )
session_region.close()
del model_region
del session_region
gc.collect()
mask_zeros2 = (prediction_regions_org2[:,:,0] == 0)
mask_lines2 = (prediction_regions_org2[:,:,0] == 3)
text_sume_early = (prediction_regions_org[:,:] == 1).sum()
@ -1247,12 +1203,6 @@ class Eynollah:
prediction_bin =np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
session_bin.close()
del model_bin
del session_bin
gc.collect()
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens)
ratio_y=1
@ -1266,11 +1216,6 @@ class Eynollah:
prediction_regions_org=prediction_regions_org[:,:,0]
mask_lines_only=(prediction_regions_org[:,:]==3)*1
session_region.close()
del model_region
del session_region
gc.collect()
mask_texts_only=(prediction_regions_org[:,:]==1)*1
mask_images_only=(prediction_regions_org[:,:]==2)*1
@ -1289,20 +1234,12 @@ class Eynollah:
text_regions_p_true=cv2.fillPoly(text_regions_p_true,pts=polygons_of_only_texts, color=(1,1,1))
K.clear_session()
return text_regions_p_true, erosion_hurts, polygons_lines_xml
except:
if self.input_binary:
prediction_bin = np.copy(img_org)
else:
session_region.close()
del model_region
del session_region
gc.collect()
model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, model_bin)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
@ -1314,15 +1251,6 @@ class Eynollah:
prediction_bin =np.repeat(prediction_bin[:, :, np.newaxis], 3, axis=2)
session_bin.close()
del model_bin
del session_bin
gc.collect()
model_region, session_region = self.start_new_session_and_model(self.model_region_dir_p_ens)
ratio_y=1
ratio_x=1
@ -1335,11 +1263,6 @@ class Eynollah:
prediction_regions_org=prediction_regions_org[:,:,0]
#mask_lines_only=(prediction_regions_org[:,:]==3)*1
session_region.close()
del model_region
del session_region
gc.collect()
#img = resize_image(img_org, int(img_org.shape[0]*1), int(img_org.shape[1]*1))
#prediction_regions_org = self.do_prediction(True, img, model_region)
@ -1349,11 +1272,6 @@ class Eynollah:
#prediction_regions_org = prediction_regions_org[:,:,0]
#prediction_regions_org[(prediction_regions_org[:,:] == 1) & (mask_zeros_y[:,:] == 1)]=0
#session_region.close()
#del model_region
#del session_region
#gc.collect()
@ -1381,7 +1299,7 @@ class Eynollah:
text_regions_p_true = cv2.fillPoly(text_regions_p_true, pts = polygons_of_only_texts, color=(1,1,1))
erosion_hurts = True
K.clear_session()
return text_regions_p_true, erosion_hurts, polygons_lines_xml
def do_order_of_regions_full_layout(self, contours_only_text_parent, contours_only_text_parent_h, boxes, textline_mask_tot):
@ -1873,9 +1791,8 @@ class Eynollah:
img_new =np.ones((height_new,width_new,img.shape[2])).astype(float)*0
img_new[h_start:h_start+img.shape[0] ,w_start: w_start+img.shape[1], : ] =img[:,:,:]
prediction_ext = self.do_prediction(patches, img_new, model_region)
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), model_region)
pre_updown = cv2.flip(pre_updown, -1)
@ -1896,9 +1813,8 @@ class Eynollah:
img_new =np.ones((height_new,width_new,img.shape[2])).astype(float)*0
img_new[h_start:h_start+img.shape[0] ,w_start: w_start+img.shape[1], : ] =img[:,:,:]
prediction_ext = self.do_prediction(patches, img_new, model_region)
pre_updown = self.do_prediction(patches, cv2.flip(img_new[:,:,:], -1), model_region)
pre_updown = cv2.flip(pre_updown, -1)
@ -1911,12 +1827,10 @@ class Eynollah:
else:
prediction_table = np.zeros(img.shape)
img_w_half = int(img.shape[1]/2.)
pre1 = self.do_prediction(patches, img[:,0:img_w_half,:], model_region)
pre2 = self.do_prediction(patches, img[:,img_w_half:,:], model_region)
pre_full = self.do_prediction(patches, img[:,:,:], model_region)
pre_updown = self.do_prediction(patches, cv2.flip(img[:,:,:], -1), model_region)
pre_updown = cv2.flip(pre_updown, -1)
@ -1939,11 +1853,6 @@ class Eynollah:
prediction_table_erode = cv2.erode(prediction_table[:,:,0], KERNEL, iterations=20)
prediction_table_erode = cv2.dilate(prediction_table_erode, KERNEL, iterations=20)
del model_region
del session_region
gc.collect()
return prediction_table_erode.astype(np.int16)
def run_graphics_and_columns(self, text_regions_p_1, num_col_classifier, num_column_is_classified, erosion_hurts):
@ -1995,7 +1904,7 @@ class Eynollah:
self.logger.info("Resizing and enhancing image...")
is_image_enhanced, img_org, img_res, num_col_classifier, num_column_is_classified, img_bin = self.resize_and_enhance_image_with_column_classifier()
self.logger.info("Image was %senhanced.", '' if is_image_enhanced else 'not ')
K.clear_session()
scale = 1
if is_image_enhanced:
if self.allow_enhancement:
@ -2019,7 +1928,7 @@ class Eynollah:
scaler_h_textline = 1 # 1.2#1.2
scaler_w_textline = 1 # 0.9#1
textline_mask_tot_ea, _ = self.textline_contours(image_page, True, scaler_h_textline, scaler_w_textline)
K.clear_session()
if self.plotter:
self.plotter.save_plot_of_textlines(textline_mask_tot_ea, image_page)
return textline_mask_tot_ea
@ -2032,7 +1941,7 @@ class Eynollah:
if self.plotter:
self.plotter.save_deskewed_image(slope_deskew)
self.logger.info("slope_deskew: %s", slope_deskew)
self.logger.info("slope_deskew: %.2f°", slope_deskew)
return slope_deskew, slope_first
def run_marginals(self, image_page, textline_mask_tot_ea, mask_images, mask_lines, num_col_classifier, slope_deskew, text_regions_p_1, table_prediction):
@ -2081,7 +1990,6 @@ class Eynollah:
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
_, _, matrix_of_lines_ch_d, splitter_y_new_d, _ = find_number_of_columns_in_document(np.repeat(text_regions_p_1_n[:, :, np.newaxis], 3, axis=2), num_col_classifier, self.tables, pixel_lines)
K.clear_session()
self.logger.info("num_col_classifier: %s", num_col_classifier)
@ -2147,7 +2055,6 @@ class Eynollah:
contours_tables = return_contours_of_interested_region(text_regions_p, pixel_img, min_area_mar)
K.clear_session()
self.logger.debug('exit run_boxes_no_full_layout')
return polygons_of_images, img_revised_tab, text_regions_p_1_n, textline_mask_tot_d, regions_without_separators_d, boxes, boxes_d, polygons_of_marginals, contours_tables
@ -2178,8 +2085,6 @@ class Eynollah:
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
num_col_d, peaks_neg_fin_d, matrix_of_lines_ch_d, splitter_y_new_d, seperators_closeup_n_d = find_number_of_columns_in_document(np.repeat(text_regions_p_1_n[:, :, np.newaxis], 3, axis=2),num_col_classifier, self.tables, pixel_lines)
K.clear_session()
gc.collect()
if num_col_classifier>=3:
if np.abs(slope_deskew) < SLOPE_THRESHOLD:
@ -2246,21 +2151,18 @@ class Eynollah:
text_regions_p[:, :][text_regions_p[:, :] == 3] = 6
text_regions_p[:, :][text_regions_p[:, :] == 4] = 8
K.clear_session()
image_page = image_page.astype(np.uint8)
regions_fully, regions_fully_only_drop = self.extract_text_regions(image_page, True, cols=num_col_classifier)
text_regions_p[:,:][regions_fully[:,:,0]==6]=6
regions_fully_only_drop = put_drop_out_from_only_drop_model(regions_fully_only_drop, text_regions_p)
regions_fully[:, :, 0][regions_fully_only_drop[:, :, 0] == 4] = 4
K.clear_session()
# plt.imshow(regions_fully[:,:,0])
# plt.show()
regions_fully = putt_bb_of_drop_capitals_of_model_in_patches_in_layout(regions_fully)
# plt.imshow(regions_fully[:,:,0])
# plt.show()
K.clear_session()
regions_fully_np, _ = self.extract_text_regions(image_page, False, cols=num_col_classifier)
# plt.imshow(regions_fully_np[:,:,0])
# plt.show()
@ -2271,7 +2173,6 @@ class Eynollah:
# plt.imshow(regions_fully_np[:,:,0])
# plt.show()
K.clear_session()
# plt.imshow(regions_fully[:,:,0])
# plt.show()
regions_fully = boosting_headers_by_longshot_region_segmentation(regions_fully, regions_fully_np, img_only_regions)
@ -2297,7 +2198,6 @@ class Eynollah:
if not self.tables:
regions_without_separators = (text_regions_p[:, :] == 1) * 1
K.clear_session()
img_revised_tab = np.copy(text_regions_p[:, :])
polygons_of_images = return_contours_of_interested_region(img_revised_tab, 5)
self.logger.debug('exit run_boxes_full_layout')
@ -2470,7 +2370,7 @@ class Eynollah:
all_found_texline_polygons = small_textlines_to_parent_adherence2(all_found_texline_polygons, textline_mask_tot_ea, num_col_classifier)
all_found_texline_polygons_marginals, boxes_marginals, _, polygons_of_marginals, all_box_coord_marginals, _, slopes_marginals = self.get_slopes_and_deskew_new_curved(polygons_of_marginals, polygons_of_marginals, cv2.erode(textline_mask_tot_ea, kernel=KERNEL, iterations=1), image_page_rotated, boxes_marginals, text_only, num_col_classifier, scale_param, slope_deskew)
all_found_texline_polygons_marginals = small_textlines_to_parent_adherence2(all_found_texline_polygons_marginals, textline_mask_tot_ea, num_col_classifier)
K.clear_session()
if self.full_layout:
if np.abs(slope_deskew) >= SLOPE_THRESHOLD:
contours_only_text_parent_d_ordered = list(np.array(contours_only_text_parent_d_ordered, dtype=object)[index_by_text_par_con])
@ -2483,8 +2383,6 @@ class Eynollah:
self.plotter.save_plot_of_layout(text_regions_p, image_page)
self.plotter.save_plot_of_layout_all(text_regions_p, image_page)
K.clear_session()
pixel_img = 4
polygons_of_drop_capitals = return_contours_of_interested_region_by_min_size(text_regions_p, pixel_img)
all_found_texline_polygons = adhere_drop_capital_region_into_corresponding_textline(text_regions_p, polygons_of_drop_capitals, contours_only_text_parent, contours_only_text_parent_h, all_box_coord, all_box_coord_h, all_found_texline_polygons, all_found_texline_polygons_h, kernel=KERNEL, curved_line=self.curved_line)

Loading…
Cancel
Save