mirror of
https://github.com/qurator-spk/sbb_binarization.git
synced 2025-06-09 12:19:56 +02:00
resolving error for inputs whcih have smaller scale than model patch
This commit is contained in:
parent
3518618a01
commit
2c7cd84649
1 changed files with 42 additions and 1 deletions
|
@ -56,6 +56,44 @@ class SbbBinarizer:
|
|||
|
||||
def predict(self, model_in, img, use_patches):
|
||||
model, model_height, model_width, n_classes = model_in
|
||||
|
||||
img_org_h = img.shape[0]
|
||||
img_org_w = img.shape[1]
|
||||
|
||||
if img.shape[0] < model_height and img.shape[1] >= model_width:
|
||||
img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] ))
|
||||
|
||||
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||
index_start_w = 0
|
||||
|
||||
img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:]
|
||||
|
||||
elif img.shape[0] >= model_height and img.shape[1] < model_width:
|
||||
img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] ))
|
||||
|
||||
index_start_h = 0
|
||||
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||
|
||||
img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||
|
||||
|
||||
elif img.shape[0] < model_height and img.shape[1] < model_width:
|
||||
img_padded = np.zeros(( model_height, model_width, img.shape[2] ))
|
||||
|
||||
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||
|
||||
img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||
|
||||
else:
|
||||
index_start_h = 0
|
||||
index_start_w = 0
|
||||
img_padded = np.copy(img)
|
||||
|
||||
|
||||
img = np.copy(img_padded)
|
||||
|
||||
|
||||
|
||||
if use_patches:
|
||||
|
||||
|
@ -179,7 +217,10 @@ class SbbBinarizer:
|
|||
|
||||
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
|
||||
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
|
||||
|
||||
|
||||
|
||||
|
||||
prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:]
|
||||
prediction_true = prediction_true.astype(np.uint8)
|
||||
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue