mirror of
				https://github.com/qurator-spk/sbb_binarization.git
				synced 2025-10-30 00:54:14 +01:00 
			
		
		
		
	
						commit
						039acd0610
					
				
					 1 changed files with 13 additions and 7 deletions
				
			
		|  | @ -29,6 +29,14 @@ class SbbBinarizer: | ||||||
|         self.model_dir = model_dir |         self.model_dir = model_dir | ||||||
|         self.log = logger if logger else logging.getLogger('SbbBinarizer') |         self.log = logger if logger else logging.getLogger('SbbBinarizer') | ||||||
| 
 | 
 | ||||||
|  |         self.start_new_session() | ||||||
|  | 
 | ||||||
|  |         self.model_files = glob('%s/*.h5' % self.model_dir) | ||||||
|  | 
 | ||||||
|  |         self.models = [] | ||||||
|  |         for model_file in self.model_files: | ||||||
|  |             self.models.append(self.load_model(model_file)) | ||||||
|  | 
 | ||||||
|     def start_new_session(self): |     def start_new_session(self): | ||||||
|         config = tf.ConfigProto() |         config = tf.ConfigProto() | ||||||
|         config.gpu_options.allow_growth = True |         config.gpu_options.allow_growth = True | ||||||
|  | @ -46,8 +54,8 @@ class SbbBinarizer: | ||||||
|         n_classes = model.layers[len(model.layers)-1].output_shape[3] |         n_classes = model.layers[len(model.layers)-1].output_shape[3] | ||||||
|         return model, model_height, model_width, n_classes |         return model, model_height, model_width, n_classes | ||||||
| 
 | 
 | ||||||
|     def predict(self, model_name, img, use_patches): |     def predict(self, model_in, img, use_patches): | ||||||
|         model, model_height, model_width, n_classes = self.load_model(model_name) |         model, model_height, model_width, n_classes = model_in | ||||||
| 
 | 
 | ||||||
|         if use_patches: |         if use_patches: | ||||||
| 
 | 
 | ||||||
|  | @ -194,13 +202,11 @@ class SbbBinarizer: | ||||||
|             raise ValueError("Must pass either a opencv2 image or an image_path") |             raise ValueError("Must pass either a opencv2 image or an image_path") | ||||||
|         if image_path is not None: |         if image_path is not None: | ||||||
|             image = cv2.imread(image_path) |             image = cv2.imread(image_path) | ||||||
|         self.start_new_session() |  | ||||||
|         list_of_model_files = glob('%s/*.h5' % self.model_dir) |  | ||||||
|         img_last = 0 |         img_last = 0 | ||||||
|         for n, model_in in enumerate(list_of_model_files): |         for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): | ||||||
|             self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files))) |             self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) | ||||||
| 
 | 
 | ||||||
|             res = self.predict(model_in, image, use_patches) |             res = self.predict(model, image, use_patches) | ||||||
| 
 | 
 | ||||||
|             img_fin = np.zeros((res.shape[0], res.shape[1], 3)) |             img_fin = np.zeros((res.shape[0], res.shape[1], 3)) | ||||||
|             res[:, :][res[:, :] == 0] = 2 |             res[:, :][res[:, :] == 0] = 2 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue