|
|
|
@ -29,6 +29,14 @@ class SbbBinarizer:
|
|
|
|
|
self.model_dir = model_dir
|
|
|
|
|
self.log = logger if logger else logging.getLogger('SbbBinarizer')
|
|
|
|
|
|
|
|
|
|
self.start_new_session()
|
|
|
|
|
|
|
|
|
|
self.model_files = glob('%s/*.h5' % self.model_dir)
|
|
|
|
|
|
|
|
|
|
self.models = []
|
|
|
|
|
for model_file in self.model_files:
|
|
|
|
|
self.models.append(self.load_model(model_file))
|
|
|
|
|
|
|
|
|
|
def start_new_session(self):
|
|
|
|
|
config = tf.ConfigProto()
|
|
|
|
|
config.gpu_options.allow_growth = True
|
|
|
|
@ -46,8 +54,8 @@ class SbbBinarizer:
|
|
|
|
|
n_classes = model.layers[len(model.layers)-1].output_shape[3]
|
|
|
|
|
return model, model_height, model_width, n_classes
|
|
|
|
|
|
|
|
|
|
def predict(self, model_name, img, use_patches):
|
|
|
|
|
model, model_height, model_width, n_classes = self.load_model(model_name)
|
|
|
|
|
def predict(self, model_in, img, use_patches):
|
|
|
|
|
model, model_height, model_width, n_classes = model_in
|
|
|
|
|
|
|
|
|
|
if use_patches:
|
|
|
|
|
|
|
|
|
@ -194,13 +202,11 @@ class SbbBinarizer:
|
|
|
|
|
raise ValueError("Must pass either a opencv2 image or an image_path")
|
|
|
|
|
if image_path is not None:
|
|
|
|
|
image = cv2.imread(image_path)
|
|
|
|
|
self.start_new_session()
|
|
|
|
|
list_of_model_files = glob('%s/*.h5' % self.model_dir)
|
|
|
|
|
img_last = 0
|
|
|
|
|
for n, model_in in enumerate(list_of_model_files):
|
|
|
|
|
self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files)))
|
|
|
|
|
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
|
|
|
|
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
|
|
|
|
|
|
|
|
|
res = self.predict(model_in, image, use_patches)
|
|
|
|
|
res = self.predict(model, image, use_patches)
|
|
|
|
|
|
|
|
|
|
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
|
|
|
|
res[:, :][res[:, :] == 0] = 2
|
|
|
|
|