mirror of
				https://github.com/qurator-spk/sbb_binarization.git
				synced 2025-10-30 00:54:14 +01:00 
			
		
		
		
	
						commit
						039acd0610
					
				
					 1 changed files with 13 additions and 7 deletions
				
			
		|  | @ -29,6 +29,14 @@ class SbbBinarizer: | |||
|         self.model_dir = model_dir | ||||
|         self.log = logger if logger else logging.getLogger('SbbBinarizer') | ||||
| 
 | ||||
|         self.start_new_session() | ||||
| 
 | ||||
|         self.model_files = glob('%s/*.h5' % self.model_dir) | ||||
| 
 | ||||
|         self.models = [] | ||||
|         for model_file in self.model_files: | ||||
|             self.models.append(self.load_model(model_file)) | ||||
| 
 | ||||
|     def start_new_session(self): | ||||
|         config = tf.ConfigProto() | ||||
|         config.gpu_options.allow_growth = True | ||||
|  | @ -46,8 +54,8 @@ class SbbBinarizer: | |||
|         n_classes = model.layers[len(model.layers)-1].output_shape[3] | ||||
|         return model, model_height, model_width, n_classes | ||||
| 
 | ||||
|     def predict(self, model_name, img, use_patches): | ||||
|         model, model_height, model_width, n_classes = self.load_model(model_name) | ||||
|     def predict(self, model_in, img, use_patches): | ||||
|         model, model_height, model_width, n_classes = model_in | ||||
| 
 | ||||
|         if use_patches: | ||||
| 
 | ||||
|  | @ -194,13 +202,11 @@ class SbbBinarizer: | |||
|             raise ValueError("Must pass either a opencv2 image or an image_path") | ||||
|         if image_path is not None: | ||||
|             image = cv2.imread(image_path) | ||||
|         self.start_new_session() | ||||
|         list_of_model_files = glob('%s/*.h5' % self.model_dir) | ||||
|         img_last = 0 | ||||
|         for n, model_in in enumerate(list_of_model_files): | ||||
|             self.log.info('Predicting with model %s [%s/%s]' % (model_in, n + 1, len(list_of_model_files))) | ||||
|         for n, (model, model_file) in enumerate(zip(self.models, self.model_files)): | ||||
|             self.log.info('Predicting with model %s [%s/%s]' % (model_file, n + 1, len(self.model_files))) | ||||
| 
 | ||||
|             res = self.predict(model_in, image, use_patches) | ||||
|             res = self.predict(model, image, use_patches) | ||||
| 
 | ||||
|             img_fin = np.zeros((res.shape[0], res.shape[1], 3)) | ||||
|             res[:, :][res[:, :] == 0] = 2 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue