sbb_binarize: Load each model only once

pull/17/head
Lucas Sulzbach 4 years ago
parent 0f7d4c589a
commit 45f1509dbc

@ -31,6 +31,12 @@ class SbbBinarizer:
self.start_new_session()
self.model_files = glob('%s/*.h5' % self.model_dir)
self.models = []
for model_file in self.model_files:
self.models.append(self.load_model(model_file))
def start_new_session(self):
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
@ -48,8 +54,8 @@ class SbbBinarizer:
n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes
def predict(self, model_name, img, use_patches):
model, model_height, model_width, n_classes = self.load_model(model_name)
def predict(self, model_in, img, use_patches):
model, model_height, model_width, n_classes = model_in
if use_patches:
@ -196,12 +202,11 @@ class SbbBinarizer:
raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None:
image = cv2.imread(image_path)
list_of_model_files = glob('%s/*.h5' % self.model_dir)
img_last = 0
for n, model_in in enumerate(list_of_model_files):
self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files)))
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
res = self.predict(model_in, image, use_patches)
res = self.predict(model, image, use_patches)
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2

Loading…
Cancel
Save