mirror of
				https://github.com/qurator-spk/sbb_binarization.git
				synced 2025-10-31 01:24:14 +01:00 
			
		
		
		
	hybrid cnn & transformer model is integrated
This commit is contained in:
		
							parent
							
								
									e4c1eb2913
								
							
						
					
					
						commit
						ffdc776192
					
				
					 3 changed files with 101 additions and 5 deletions
				
			
		|  | @ -1,7 +1,7 @@ | |||
| """ | ||||
| sbb_binarize CLI | ||||
| """ | ||||
| 
 | ||||
| import click | ||||
| from click import command, option, argument, version_option, types | ||||
| from .sbb_binarize import SbbBinarizer | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,14 +17,72 @@ sys.stderr = open(devnull, 'w') | |||
| import tensorflow as tf | ||||
| from tensorflow.keras.models import load_model | ||||
| from tensorflow.python.keras import backend as tensorflow_backend | ||||
| from tensorflow.keras import layers | ||||
| import tensorflow.keras.losses | ||||
| from tensorflow.keras.layers import * | ||||
| sys.stderr = stderr | ||||
| 
 | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| projection_dim = 64 | ||||
| patch_size = 1 | ||||
| num_patches =14*14 | ||||
| 
 | ||||
| def resize_image(img_in, input_height, input_width): | ||||
|     return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST) | ||||
| 
 | ||||
| 
 | ||||
| class Patches(layers.Layer): | ||||
|     def __init__(self, **kwargs): | ||||
|         super(Patches, self).__init__() | ||||
|         self.patch_size = patch_size | ||||
| 
 | ||||
|     def call(self, images): | ||||
|         batch_size = tf.shape(images)[0] | ||||
|         patches = tf.image.extract_patches( | ||||
|             images=images, | ||||
|             sizes=[1, self.patch_size, self.patch_size, 1], | ||||
|             strides=[1, self.patch_size, self.patch_size, 1], | ||||
|             rates=[1, 1, 1, 1], | ||||
|             padding="VALID", | ||||
|         ) | ||||
|         patch_dims = patches.shape[-1] | ||||
|         patches = tf.reshape(patches, [batch_size, -1, patch_dims]) | ||||
|         return patches | ||||
|     def get_config(self): | ||||
| 
 | ||||
|         config = super().get_config().copy() | ||||
|         config.update({ | ||||
|             'patch_size': self.patch_size, | ||||
|         }) | ||||
|         return config | ||||
|      | ||||
|      | ||||
| class PatchEncoder(layers.Layer): | ||||
|     def __init__(self, **kwargs): | ||||
|         super(PatchEncoder, self).__init__() | ||||
|         self.num_patches = num_patches | ||||
|         self.projection = layers.Dense(units=projection_dim) | ||||
|         self.position_embedding = layers.Embedding( | ||||
|             input_dim=num_patches, output_dim=projection_dim | ||||
|         ) | ||||
| 
 | ||||
|     def call(self, patch): | ||||
|         positions = tf.range(start=0, limit=self.num_patches, delta=1) | ||||
|         encoded = self.projection(patch) + self.position_embedding(positions) | ||||
|         return encoded | ||||
|     def get_config(self): | ||||
| 
 | ||||
|         config = super().get_config().copy() | ||||
|         config.update({ | ||||
|             'num_patches': self.num_patches, | ||||
|             'projection': self.projection, | ||||
|             'position_embedding': self.position_embedding, | ||||
|         }) | ||||
|         return config | ||||
| 
 | ||||
| class SbbBinarizer: | ||||
| 
 | ||||
|     def __init__(self, model_dir, logger=None): | ||||
|  | @ -52,7 +110,10 @@ class SbbBinarizer: | |||
|         del self.session | ||||
| 
 | ||||
|     def load_model(self, model_name): | ||||
|         model = load_model(join(self.model_dir, model_name), compile=False) | ||||
|         try: | ||||
|             model = load_model(join(self.model_dir, model_name), compile=False) | ||||
|         except: | ||||
|             model = load_model(join(self.model_dir, model_name) , compile=False,custom_objects = {"PatchEncoder": PatchEncoder, "Patches": Patches}) | ||||
|         model_height = model.layers[len(model.layers)-1].output_shape[1] | ||||
|         model_width = model.layers[len(model.layers)-1].output_shape[2] | ||||
|         n_classes = model.layers[len(model.layers)-1].output_shape[3] | ||||
|  | @ -153,12 +214,47 @@ class SbbBinarizer: | |||
|                         index_y_d = img_h - model_height | ||||
| 
 | ||||
|                     img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :] | ||||
|                      | ||||
|                     h_res = int( img_patch.shape[0]/1.05) | ||||
|                     w_res = int( img_patch.shape[1]/1.05) | ||||
|                      | ||||
|                     img_patch_resize = resize_image(img_patch, h_res, w_res) | ||||
|                      | ||||
|                     img_patch_resized_padded =np.ones((img_patch.shape[0],img_patch.shape[1],img_patch.shape[2])).astype(float)#self.do_padding() | ||||
|                      | ||||
|                     h_start=int( abs(img_patch.shape[0]-img_patch_resize.shape[0])/2. ) | ||||
|                      | ||||
|                     w_start=int( abs(img_patch.shape[1]-img_patch_resize.shape[1])/2. ) | ||||
|                      | ||||
|                     img_patch_resized_padded[h_start:h_start+img_patch_resize.shape[0],w_start:w_start+img_patch_resize.shape[1],:]=np.copy(img_patch_resize[:,:,:]) | ||||
|                      | ||||
|                     label_p_pred_padded = model.predict(img_patch_resized_padded.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) | ||||
| 
 | ||||
|                     label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) | ||||
| 
 | ||||
|                     seg = np.argmax(label_p_pred, axis=3)[0] | ||||
|                     #seg = np.argmax(label_p_pred, axis=3)[0] | ||||
| 
 | ||||
|                     seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) | ||||
|                     #label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) | ||||
| 
 | ||||
|                     seg = np.argmax(label_p_pred, axis=3)[0] | ||||
|                      | ||||
|                      | ||||
|                     seg_padded = np.argmax(label_p_pred_padded, axis=3)[0] | ||||
|                      | ||||
|                     seg_padded_take_core = seg_padded[h_start:h_start+img_patch_resize.shape[0],w_start:w_start+img_patch_resize.shape[1]] | ||||
|                      | ||||
|                     seg_padded_take_core_org_size= resize_image(seg_padded_take_core, img_patch.shape[0], img_patch.shape[1]) | ||||
|                      | ||||
|                     #print(seg_padded_take_core_org_size,'sag padded') | ||||
|                     #print(seg,'sag') | ||||
|                      | ||||
|                     seg_tot  = seg_padded_take_core_org_size+0#seg | ||||
|                      | ||||
|                     seg_tot[seg_tot>1]=1 | ||||
| 
 | ||||
|                     seg_color = np.repeat(seg_tot[:, :, np.newaxis], 3, axis=2) | ||||
| 
 | ||||
|                     #seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) | ||||
| 
 | ||||
|                     if i == 0 and j == 0: | ||||
|                         seg_color = seg_color[0:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :] | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue