Merge pull request #17 from sulzbals/ram-consumption

Fix space leak
pull/18/head
Konstantin Baierer 4 years ago committed by GitHub
commit 039acd0610
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,6 +29,14 @@ class SbbBinarizer:
self.model_dir = model_dir
self.log = logger if logger else logging.getLogger('SbbBinarizer')
self.start_new_session()
self.model_files = glob('%s/*.h5' % self.model_dir)
self.models = []
for model_file in self.model_files:
self.models.append(self.load_model(model_file))
def start_new_session(self):
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
@ -46,8 +54,8 @@ class SbbBinarizer:
n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes
def predict(self, model_name, img, use_patches):
model, model_height, model_width, n_classes = self.load_model(model_name)
def predict(self, model_in, img, use_patches):
model, model_height, model_width, n_classes = model_in
if use_patches:
@ -194,13 +202,11 @@ class SbbBinarizer:
raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None:
image = cv2.imread(image_path)
self.start_new_session()
list_of_model_files = glob('%s/*.h5' % self.model_dir)
img_last = 0
for n, model_in in enumerate(list_of_model_files):
self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files)))
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
res = self.predict(model_in, image, use_patches)
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2

Loading…
Cancel
Save