mirror of
https://github.com/qurator-spk/sbb_binarization.git
synced 2025-06-09 12:19:56 +02:00
sbb_binarize: Load each model only once
This commit is contained in:
parent
0f7d4c589a
commit
45f1509dbc
1 changed files with 11 additions and 6 deletions
|
@ -31,6 +31,12 @@ class SbbBinarizer:
|
|||
|
||||
self.start_new_session()
|
||||
|
||||
self.model_files = glob('%s/*.h5' % self.model_dir)
|
||||
|
||||
self.models = []
|
||||
for model_file in self.model_files:
|
||||
self.models.append(self.load_model(model_file))
|
||||
|
||||
def start_new_session(self):
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
@ -48,8 +54,8 @@ class SbbBinarizer:
|
|||
n_classes = model.layers[len(model.layers)-1].output_shape[3]
|
||||
return model, model_height, model_width, n_classes
|
||||
|
||||
def predict(self, model_name, img, use_patches):
|
||||
model, model_height, model_width, n_classes = self.load_model(model_name)
|
||||
def predict(self, model_in, img, use_patches):
|
||||
model, model_height, model_width, n_classes = model_in
|
||||
|
||||
if use_patches:
|
||||
|
||||
|
@ -196,12 +202,11 @@ class SbbBinarizer:
|
|||
raise ValueError("Must pass either a opencv2 image or an image_path")
|
||||
if image_path is not None:
|
||||
image = cv2.imread(image_path)
|
||||
list_of_model_files = glob('%s/*.h5' % self.model_dir)
|
||||
img_last = 0
|
||||
for n, model_in in enumerate(list_of_model_files):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files)))
|
||||
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
|
||||
self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files)))
|
||||
|
||||
res = self.predict(model_in, image, use_patches)
|
||||
res = self.predict(model, image, use_patches)
|
||||
|
||||
img_fin = np.zeros((res.shape[0], res.shape[1], 3))
|
||||
res[:, :][res[:, :] == 0] = 2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue