From 45f1509dbce4957eb4dc0df4379c301f3d378840 Mon Sep 17 00:00:00 2001 From: Lucas Sulzbach Date: Sun, 1 Nov 2020 11:26:48 -0300 Subject: [PATCH] sbb_binarize: Load each model only once --- sbb_binarize/sbb_binarize.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py index 5eb342a..52d7853 100644 --- a/sbb_binarize/sbb_binarize.py +++ b/sbb_binarize/sbb_binarize.py @@ -31,6 +31,12 @@ class SbbBinarizer: self.start_new_session() + self.model_files = glob('%s/*.h5' % self.model_dir) + + self.models = [] + for model_file in self.model_files: + self.models.append(self.load_model(model_file)) + def start_new_session(self): config = tf.ConfigProto() config.gpu_options.allow_growth = True @@ -48,8 +54,8 @@ class SbbBinarizer: n_classes = model.layers[len(model.layers)-1].output_shape[3] return model, model_height, model_width, n_classes - def predict(self, model_name, img, use_patches): - model, model_height, model_width, n_classes = self.load_model(model_name) + def predict(self, model_in, img, use_patches): + model, model_height, model_width, n_classes = model_in if use_patches: @@ -196,12 +202,11 @@ class SbbBinarizer: raise ValueError("Must pass either a opencv2 image or an image_path") if image_path is not None: image = cv2.imread(image_path) - list_of_model_files = glob('%s/*.h5' % self.model_dir) img_last = 0 - for n, model_in in enumerate(list_of_model_files): - self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files))) + for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): + self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) - res = self.predict(model_in, image, use_patches) + res = self.predict(model, image, use_patches) img_fin = np.zeros((res.shape[0], res.shape[1], 3)) res[:, :][res[:, :] == 0] = 2