diff --git a/sbb_binarize/sbb_binarize.py b/sbb_binarize/sbb_binarize.py index 52d7853..cd0970c 100644 --- a/sbb_binarize/sbb_binarize.py +++ b/sbb_binarize/sbb_binarize.py @@ -56,6 +56,44 @@ class SbbBinarizer: def predict(self, model_in, img, use_patches): model, model_height, model_width, n_classes = model_in + + img_org_h = img.shape[0] + img_org_w = img.shape[1] + + if img.shape[0] < model_height and img.shape[1] >= model_width: + img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] )) + + index_start_h = int( abs( img.shape[0] - model_height) /2.) + index_start_w = 0 + + img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:] + + elif img.shape[0] >= model_height and img.shape[1] < model_width: + img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] )) + + index_start_h = 0 + index_start_w = int( abs( img.shape[1] - model_width) /2.) + + img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:] + + + elif img.shape[0] < model_height and img.shape[1] < model_width: + img_padded = np.zeros(( model_height, model_width, img.shape[2] )) + + index_start_h = int( abs( img.shape[0] - model_height) /2.) + index_start_w = int( abs( img.shape[1] - model_width) /2.) + + img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:] + + else: + index_start_h = 0 + index_start_w = 0 + img_padded = np.copy(img) + + + img = np.copy(img_padded) + + if use_patches: @@ -179,7 +217,10 @@ class SbbBinarizer: mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color - + + + + prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:] prediction_true = prediction_true.astype(np.uint8) else: