Mirror of https://github.com/qurator-spk/sbb_binarization.git

Commit ffdc776192 (parent e4c1eb2913): "hybrid cnn & transformer model is integrated"

3 changed files with 101 additions and 5 deletions
requirements.txt
@@ -2,4 +2,4 @@ numpy
 setuptools >= 41
 opencv-python-headless
 ocrd >= 2.22.3
-tensorflow >= 2.4.0
+tensorflow == 2.4.*
sbb_binarize/cli.py
@@ -1,7 +1,7 @@
 """
 sbb_binarize CLI
 """
-
+import click
 from click import command, option, argument, version_option, types
 from .sbb_binarize import SbbBinarizer
 
sbb_binarize/sbb_binarize.py
@@ -17,14 +17,72 @@ sys.stderr = open(devnull, 'w')
 import tensorflow as tf
 from tensorflow.keras.models import load_model
 from tensorflow.python.keras import backend as tensorflow_backend
+from tensorflow.keras import layers
+import tensorflow.keras.losses
+from tensorflow.keras.layers import *
 sys.stderr = stderr
 
 
 import logging
 
+
+projection_dim = 64
+patch_size = 1
+num_patches = 14*14
+
 def resize_image(img_in, input_height, input_width):
     return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
 
+
+class Patches(layers.Layer):
+    def __init__(self, **kwargs):
+        super(Patches, self).__init__()
+        self.patch_size = patch_size
+
+    def call(self, images):
+        batch_size = tf.shape(images)[0]
+        patches = tf.image.extract_patches(
+            images=images,
+            sizes=[1, self.patch_size, self.patch_size, 1],
+            strides=[1, self.patch_size, self.patch_size, 1],
+            rates=[1, 1, 1, 1],
+            padding="VALID",
+        )
+        patch_dims = patches.shape[-1]
+        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
+        return patches
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'patch_size': self.patch_size,
+        })
+        return config
+
+
+class PatchEncoder(layers.Layer):
+    def __init__(self, **kwargs):
+        super(PatchEncoder, self).__init__()
+        self.num_patches = num_patches
+        self.projection = layers.Dense(units=projection_dim)
+        self.position_embedding = layers.Embedding(
+            input_dim=num_patches, output_dim=projection_dim
+        )
+
+    def call(self, patch):
+        positions = tf.range(start=0, limit=self.num_patches, delta=1)
+        encoded = self.projection(patch) + self.position_embedding(positions)
+        return encoded
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update({
+            'num_patches': self.num_patches,
+            'projection': self.projection,
+            'position_embedding': self.position_embedding,
+        })
+        return config
+
 class SbbBinarizer:
 
     def __init__(self, model_dir, logger=None):
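Note (not part of the commit): Patches and PatchEncoder above follow the standard Keras ViT pattern, so a hybrid CNN/transformer model of the kind the commit message describes would typically splice them between a CNN backbone and a decoder roughly as sketched below. The 14x14x64 feature-map shape matches num_patches = 14*14 and projection_dim = 64 from the globals above, but the head count, MLP width, and the transformer_block helper itself are illustrative assumptions, not the architecture stored in the released model files.

# Minimal sketch only; assumes the Patches/PatchEncoder classes defined above.
import tensorflow as tf
from tensorflow.keras import layers

def transformer_block(encoded_patches, num_heads=4, projection_dim=64):
    # Self-attention over the patch sequence, with a residual connection.
    x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)
    x2 = layers.Add()([attention_output, encoded_patches])
    # Position-wise feed-forward network, again with a residual connection.
    x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
    x3 = layers.Dense(projection_dim * 2, activation="gelu")(x3)
    x3 = layers.Dense(projection_dim)(x3)
    return layers.Add()([x3, x2])

# Illustrative wiring: a CNN feature map is split into 1x1 "patches",
# position-encoded, run through one encoder block and reshaped back into a
# spatial map that a segmentation decoder could consume.
feature_map = tf.keras.Input(shape=(14, 14, 64))   # assumed backbone output
patches = Patches()(feature_map)                    # -> (batch, 196, 64)
encoded = PatchEncoder()(patches)                   # project + add positions
encoded = transformer_block(encoded)                # -> (batch, 196, 64)
spatial = layers.Reshape((14, 14, 64))(encoded)     # back to a spatial map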
sbb_binarize/sbb_binarize.py
@@ -52,7 +110,10 @@ class SbbBinarizer:
         del self.session
 
     def load_model(self, model_name):
-        model = load_model(join(self.model_dir, model_name), compile=False)
+        try:
+            model = load_model(join(self.model_dir, model_name), compile=False)
+        except:
+            model = load_model(join(self.model_dir, model_name), compile=False, custom_objects={"PatchEncoder": PatchEncoder, "Patches": Patches})
         model_height = model.layers[len(model.layers)-1].output_shape[1]
         model_width = model.layers[len(model.layers)-1].output_shape[2]
         n_classes = model.layers[len(model.layers)-1].output_shape[3]
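Note (not part of the commit): a minimal sketch of the deserialization problem the new try/except works around. Keras can only rebuild layers it already knows about, so a saved model containing Patches or PatchEncoder layers fails to load unless the classes are supplied via custom_objects or, equivalently, a custom_object_scope as below. The filename model_hybrid.h5 is a hypothetical placeholder.

import tensorflow as tf
from tensorflow.keras.models import load_model

# Equivalent to passing custom_objects=... directly to load_model().
with tf.keras.utils.custom_object_scope({"Patches": Patches,
                                         "PatchEncoder": PatchEncoder}):
    model = load_model("model_hybrid.h5", compile=False)

# The binarizer reads its expected input geometry from the last layer,
# e.g. (None, model_height, model_width, n_classes).
print(model.layers[-1].output_shape)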
sbb_binarize/sbb_binarize.py
@@ -154,11 +215,46 @@ class SbbBinarizer:
 
                     img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
 
+                    h_res = int(img_patch.shape[0] / 1.05)
+                    w_res = int(img_patch.shape[1] / 1.05)
+
+                    img_patch_resize = resize_image(img_patch, h_res, w_res)
+
+                    img_patch_resized_padded = np.ones((img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])).astype(float)  # self.do_padding()
+
+                    h_start = int(abs(img_patch.shape[0] - img_patch_resize.shape[0]) / 2.)
+                    w_start = int(abs(img_patch.shape[1] - img_patch_resize.shape[1]) / 2.)
+
+                    img_patch_resized_padded[h_start:h_start + img_patch_resize.shape[0], w_start:w_start + img_patch_resize.shape[1], :] = np.copy(img_patch_resize[:, :, :])
+
+                    label_p_pred_padded = model.predict(img_patch_resized_padded.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]))
+
                     label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]))
 
+                    #seg = np.argmax(label_p_pred, axis=3)[0]
+
+                    #label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]))
+
                     seg = np.argmax(label_p_pred, axis=3)[0]
 
-                    seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
+                    seg_padded = np.argmax(label_p_pred_padded, axis=3)[0]
+
+                    seg_padded_take_core = seg_padded[h_start:h_start + img_patch_resize.shape[0], w_start:w_start + img_patch_resize.shape[1]]
+
+                    seg_padded_take_core_org_size = resize_image(seg_padded_take_core, img_patch.shape[0], img_patch.shape[1])
+
+                    #print(seg_padded_take_core_org_size, 'sag padded')
+                    #print(seg, 'sag')
+
+                    seg_tot = seg_padded_take_core_org_size + 0  # seg
+
+                    seg_tot[seg_tot > 1] = 1
+
+                    seg_color = np.repeat(seg_tot[:, :, np.newaxis], 3, axis=2)
+
+                    #seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
 
                     if i == 0 and j == 0:
                         seg_color = seg_color[0:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
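Note (not part of the commit): a self-contained sketch of the shrink/pad/predict/crop/upscale step the new per-patch code performs, factored into one function for readability. predict_fn stands in for model.predict and is assumed to return per-pixel class probabilities of shape (1, H, W, n_classes); the 1.05 shrink factor and the final clamp of labels above 1 mirror the diff.

import cv2
import numpy as np

def predict_patch_shrunk(img_patch, predict_fn, shrink=1.05):
    h, w, c = img_patch.shape
    # Shrink the patch slightly, then paste it centred on a blank canvas of
    # the original size, so the network still sees its expected input shape.
    h_res, w_res = int(h / shrink), int(w / shrink)
    resized = cv2.resize(img_patch, (w_res, h_res), interpolation=cv2.INTER_NEAREST)
    canvas = np.ones((h, w, c), dtype=float)
    h0 = (h - h_res) // 2
    w0 = (w - w_res) // 2
    canvas[h0:h0 + h_res, w0:w0 + w_res, :] = resized
    # Predict on the padded canvas and take per-pixel argmax labels.
    pred = predict_fn(canvas.reshape(1, h, w, c))
    seg = np.argmax(pred, axis=3)[0]
    # Cut out the region that actually held image data and scale the label
    # map back up to the original patch size.
    core = seg[h0:h0 + h_res, w0:w0 + w_res].astype(np.uint8)
    core = cv2.resize(core, (w, h), interpolation=cv2.INTER_NEAREST)
    core[core > 1] = 1   # clamp, mirroring seg_tot[seg_tot > 1] = 1
    return core

In the diff itself, both the plain prediction (label_p_pred) and the padded one (label_p_pred_padded) are still computed, but only the padded result (seg_tot) feeds into seg_color and thus into the stitched output.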