inference is implemented with batch size bigger than 1

binarization_flow_from_directory
vahidrezanezhad 3 months ago
parent 3095498162
commit 93cba20810

@ -59,7 +59,7 @@ class SbbBinarizer:
n_classes = model.layers[len(model.layers)-1].output_shape[3] n_classes = model.layers[len(model.layers)-1].output_shape[3]
return model, model_height, model_width, n_classes return model, model_height, model_width, n_classes
def predict(self, model_in, img, use_patches): def predict(self, model_in, img, use_patches, n_batch_inference=5):
tensorflow_backend.set_session(self.session) tensorflow_backend.set_session(self.session)
model, model_height, model_width, n_classes = model_in model, model_height, model_width, n_classes = model_in
@ -128,6 +128,18 @@ class SbbBinarizer:
nyf = int(nyf) + 1 nyf = int(nyf) + 1
else: else:
nyf = int(nyf) nyf = int(nyf)
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, model_height, model_width,3))
for i in range(nxf): for i in range(nxf):
for j in range(nyf): for j in range(nyf):
@ -152,77 +164,82 @@ class SbbBinarizer:
if index_y_u > img_h: if index_y_u > img_h:
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - model_height index_y_d = img_h - model_height
img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
list_i_s.append(i)
label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2])) list_j_s.append(j)
list_x_u.append(index_x_u)
seg = np.argmax(label_p_pred, axis=3)[0] list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2) list_y_u.append(index_y_u)
if i == 0 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :] img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
seg = seg[0:seg.shape[0] - margin, 0:seg.shape[1] - margin]
batch_indexer = batch_indexer + 1
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color
elif i == nxf-1 and j == nyf-1: if batch_indexer == n_batch_inference:
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0] label_p_pred = model.predict(img_patch,verbose=0)
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg seg = np.argmax(label_p_pred, axis=3)
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color
#print(seg.shape, len(seg), len(list_i_s))
elif i == 0 and j == nyf-1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :] indexer_inside_batch = 0
seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin] for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch,:,:]
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg seg_color = np.repeat(seg_in[:, :, np.newaxis], 3, axis=2)
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color
index_y_u_in = list_y_u[indexer_inside_batch]
elif i == nxf-1 and j == 0: index_y_d_in = list_y_d[indexer_inside_batch]
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0] index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color if i_batch == 0 and j_batch == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
elif i == 0 and j != 0 and j != nyf-1: prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :] elif i_batch == nxf - 1 and j_batch == nyf - 1:
seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin] seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg elif i_batch == 0 and j_batch == nyf - 1:
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
elif i == nxf-1 and j != 0 and j != nyf-1: elif i_batch == nxf - 1 and j_batch == 0:
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :] seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0] prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
elif i != 0 and i != nxf-1 and j == 0: seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :] prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin] elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
elif i != 0 and i != nxf-1 and j == nyf-1: prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :] else:
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin] seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin, :] = seg_color indexer_inside_batch = indexer_inside_batch +1
else:
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :] list_i_s = []
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - margin] list_j_s = []
list_x_u = []
mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg list_x_d = []
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, model_height, model_width,3))

Loading…
Cancel
Save