mirror of
https://github.com/qurator-spk/sbb_binarization.git
synced 2025-06-07 19:35:04 +02:00
resolving error for inputs whcih have smaller scale than model patch
This commit is contained in:
parent
3518618a01
commit
2c7cd84649
1 changed files with 42 additions and 1 deletions
|
@ -57,6 +57,44 @@ class SbbBinarizer:
|
||||||
def predict(self, model_in, img, use_patches):
|
def predict(self, model_in, img, use_patches):
|
||||||
model, model_height, model_width, n_classes = model_in
|
model, model_height, model_width, n_classes = model_in
|
||||||
|
|
||||||
|
img_org_h = img.shape[0]
|
||||||
|
img_org_w = img.shape[1]
|
||||||
|
|
||||||
|
if img.shape[0] < model_height and img.shape[1] >= model_width:
|
||||||
|
img_padded = np.zeros(( model_height, img.shape[1], img.shape[2] ))
|
||||||
|
|
||||||
|
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||||
|
index_start_w = 0
|
||||||
|
|
||||||
|
img_padded [ index_start_h: index_start_h+img.shape[0], :, : ] = img[:,:,:]
|
||||||
|
|
||||||
|
elif img.shape[0] >= model_height and img.shape[1] < model_width:
|
||||||
|
img_padded = np.zeros(( img.shape[0], model_width, img.shape[2] ))
|
||||||
|
|
||||||
|
index_start_h = 0
|
||||||
|
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||||
|
|
||||||
|
img_padded [ :, index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||||
|
|
||||||
|
|
||||||
|
elif img.shape[0] < model_height and img.shape[1] < model_width:
|
||||||
|
img_padded = np.zeros(( model_height, model_width, img.shape[2] ))
|
||||||
|
|
||||||
|
index_start_h = int( abs( img.shape[0] - model_height) /2.)
|
||||||
|
index_start_w = int( abs( img.shape[1] - model_width) /2.)
|
||||||
|
|
||||||
|
img_padded [ index_start_h: index_start_h+img.shape[0], index_start_w: index_start_w+img.shape[1], : ] = img[:,:,:]
|
||||||
|
|
||||||
|
else:
|
||||||
|
index_start_h = 0
|
||||||
|
index_start_w = 0
|
||||||
|
img_padded = np.copy(img)
|
||||||
|
|
||||||
|
|
||||||
|
img = np.copy(img_padded)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if use_patches:
|
if use_patches:
|
||||||
|
|
||||||
margin = int(0.1 * model_width)
|
margin = int(0.1 * model_width)
|
||||||
|
@ -180,6 +218,9 @@ class SbbBinarizer:
|
||||||
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
|
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
|
||||||
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
|
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
prediction_true = prediction_true[index_start_h: index_start_h+img_org_h, index_start_w: index_start_w+img_org_w,:]
|
||||||
prediction_true = prediction_true.astype(np.uint8)
|
prediction_true = prediction_true.astype(np.uint8)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue