diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 1fc4ce0..643c83b 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -469,25 +469,23 @@ class Eynollah: self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) margin = int(marginal_of_patch_percent * img_height_model) + window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin)) width_mid = img_width_model - 2 * margin height_mid = img_height_model - 2 * margin img_h = img.shape[0] img_w = img.shape[1] - if is_enhancement: - prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8) - else: - prediction = np.zeros((img_h, img_w), dtype=np.uint8) - if thresholding_for_artificial_class: - mask_artificial_class = np.zeros((img_h, img_w), dtype=bool) - nxf = math.ceil(img_w / float(width_mid)) - nyf = math.ceil(img_h / float(height_mid)) + prediction = None + nxf = math.ceil((img_w - 2.0 * margin) / width_mid) + nyf = math.ceil((img_h - 2.0 * margin) / height_mid) batch_i = [] batch_j = [] batch_x_u = [] batch_x_d = [] + batch_x_s = [] batch_y_u = [] batch_y_d = [] + batch_y_s = [] batch = 0 img_patch = np.zeros((n_batch_inference, @@ -498,83 +496,101 @@ class Eynollah: for j in range(nyf): index_x_d = i * width_mid index_x_u = index_x_d + img_width_model - index_y_d = j * height_mid - index_y_u = index_y_d + img_height_model if index_x_u > img_w: + index_x_s = index_x_u - img_w index_x_u = img_w index_x_d = img_w - img_width_model + else: + index_x_s = 0 + index_y_d = j * height_mid + index_y_u = index_y_d + img_height_model if index_y_u > img_h: + index_y_s = index_y_u - img_h index_y_u = img_h index_y_d = img_h - img_height_model + else: + index_y_s = 0 batch_i.append(i) batch_j.append(j) batch_x_u.append(index_x_u) batch_x_d.append(index_x_d) + batch_x_s.append(index_x_s) batch_y_d.append(index_y_d) batch_y_u.append(index_y_u) + batch_y_s.append(index_y_s) img_patch[batch] = img[index_y_d: index_y_u, index_x_d: index_x_u] batch += 1 - if (batch == 
n_batch_inference or # last batch i == nxf - 1 and j == nyf - 1): self.logger.debug("predicting patches on %s", str(img_patch.shape)) label_p_pred = model.predict(img_patch, verbose=0) - if is_enhancement: - seg = (label_p_pred * 255).astype(np.uint8) - else: - seg = np.argmax(label_p_pred, axis=3) - - if thresholding_for_some_classes: - seg_mask_label( - seg, label_p_pred[:,:,:,4] > 0.03, - label=4) # - seg_mask_label( - seg, label_p_pred[:,:,:,0] > 0.25, - label=0) # bg - seg_mask_label( - seg, label_p_pred[:,:,:,3] > 0.10 & seg == 0, - label=3) # line - - if thresholding_for_artificial_class: - seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class + if prediction is None: + # now we know the number of classes + prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float) for batch in range(batch): where = np.index_exp[batch_y_d[batch]: batch_y_u[batch], batch_x_d[batch]: batch_x_u[batch]] - iny = slice(margin if batch_j[batch] else None, - (-margin or None) if batch_j[batch] < nyf - 1 else None) - inx = slice(margin if batch_i[batch] else None, - (-margin or None) if batch_i[batch] < nxf - 1 else None) - inbox = np.index_exp[iny, inx] - prediction[where][inbox] = seg[batch][inbox] - if thresholding_for_artificial_class: - mask_artificial_class[where][inbox] = seg_art[batch][inbox] + # shorter window on last tile + part = np.index_exp[batch_y_s[batch]:, + batch_x_s[batch]:] + # normalize probability (where windows overlap) + attenuation_y = np.ones(img_height_model - batch_y_s[batch]) + attenuation_x = np.ones(img_width_model - batch_x_s[batch]) + if margin and batch_j[batch] > 0: + attenuation_y[:2 * margin] = window + if margin and batch_j[batch] < nyf - 1: + attenuation_y[-2 * margin:] = 1 - window + if margin and batch_i[batch] > 0: + attenuation_x[:2 * margin] = window + if margin and batch_i[batch] < nxf - 1: + attenuation_x[-2 * margin:] = 1 - window + label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis] + 
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis] + prediction[where][part] += label_p_pred[batch][part] batch_i = [] batch_j = [] batch_x_u = [] batch_x_d = [] + batch_x_s = [] batch_y_u = [] batch_y_d = [] + batch_y_s = [] batch = 0 img_patch[:] = 0 + if is_enhancement: + seg = (prediction * 255).astype(np.uint8) + else: + seg = np.argmax(prediction, axis=2) + if thresholding_for_some_classes: + seg_mask_label( + seg, prediction[:, :, 4] > 0.03, + label=4) # + seg_mask_label( + seg, prediction[:, :, 0] > 0.25, + label=0) # bg + seg_mask_label( + seg, (prediction[:, :, 3] > 0.10) & (seg == 0), + label=3) # line if thresholding_for_artificial_class: - seg_mask_label(prediction, mask_artificial_class, + seg_art = prediction[:, :, artificial_class] >= threshold_art_class + seg_mask_label(seg, seg_art, label=artificial_class, only=True, skeletonize=True, dilate=3) if img_h != img_h_page or img_w != img_w_page: - prediction = resize_image(prediction, img_h_page, img_w_page) + seg = resize_image(seg, img_h_page, img_w_page) gc.collect() - return prediction + return seg def do_prediction_new_concept( self, patches, img, model, @@ -630,14 +646,12 @@ class Eynollah: self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) margin = int(marginal_of_patch_percent * img_height_model) + window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin)) width_mid = img_width_model - 2 * margin height_mid = img_height_model - 2 * margin img_h = img.shape[0] img_w = img.shape[1] - prediction = np.zeros((img_h, img_w), dtype=np.uint8) - confidence = np.zeros((img_h, img_w)) - if thresholding_for_artificial_class: - mask_artificial_class = np.zeros((img_h, img_w), dtype=bool) + prediction = None nxf = math.ceil((img_w - 2.0 * margin) / width_mid) nyf = math.ceil((img_h - 2.0 * margin) / height_mid) @@ -645,8 +659,10 @@ class Eynollah: batch_j = [] batch_x_u = [] batch_x_d = [] + batch_x_s = [] batch_y_u = [] batch_y_d = [] + batch_y_s = [] batch = 0 
img_patch = np.zeros((n_batch_inference, img_height_model, @@ -657,20 +673,28 @@ class Eynollah: index_x_d = i * width_mid index_x_u = index_x_d + img_width_model if index_x_u > img_w: + index_x_s = index_x_u - img_w index_x_u = img_w index_x_d = img_w - img_width_model + else: + index_x_s = 0 index_y_d = j * height_mid index_y_u = index_y_d + img_height_model if index_y_u > img_h: + index_y_s = index_y_u - img_h index_y_u = img_h index_y_d = img_h - img_height_model + else: + index_y_s = 0 batch_i.append(i) batch_j.append(j) batch_x_u.append(index_x_u) batch_x_d.append(index_x_d) + batch_x_s.append(index_x_s) batch_y_d.append(index_y_d) batch_y_u.append(index_y_u) + batch_y_s.append(index_y_s) img_patch[batch] = img[index_y_d: index_y_u, index_x_d: index_x_u] @@ -679,43 +703,56 @@ class Eynollah: # last batch i == nxf - 1 and j == nyf - 1): self.logger.debug("predicting patches on %s", str(img_patch.shape)) - label_p_pred = model.predict(img_patch,verbose=0) - seg = np.argmax(label_p_pred, axis=3) - conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] - if thresholding_for_artificial_class: - seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class + label_p_pred = model.predict(img_patch, verbose=0) + if prediction is None: + # now we know the number of classes + prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float) for batch in range(batch): where = np.index_exp[batch_y_d[batch]: batch_y_u[batch], batch_x_d[batch]: batch_x_u[batch]] - iny = slice(margin if batch_j[batch] else None, - (-margin or None) if batch_j[batch] < nyf - 1 else None) - inx = slice(margin if batch_i[batch] else None, - (-margin or None) if batch_i[batch] < nxf - 1 else None) - inbox = np.index_exp[iny, inx] - prediction[where][inbox] = seg[batch][inbox] - confidence[where][inbox] = conf[batch][inbox] - if thresholding_for_artificial_class: - mask_artificial_class[where][inbox] = seg_art[batch][inbox] + # shorter window on last tile + part = 
np.index_exp[batch_y_s[batch]:, + batch_x_s[batch]:] + # normalize probability (where windows overlap) + attenuation_y = np.ones(img_height_model - batch_y_s[batch]) + attenuation_x = np.ones(img_width_model - batch_x_s[batch]) + if margin and batch_j[batch] > 0: + attenuation_y[:2 * margin] = window + if margin and batch_j[batch] < nyf - 1: + attenuation_y[-2 * margin:] = 1 - window + if margin and batch_i[batch] > 0: + attenuation_x[:2 * margin] = window + if margin and batch_i[batch] < nxf - 1: + attenuation_x[-2 * margin:] = 1 - window + label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis] + label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis] + prediction[where][part] += label_p_pred[batch][part] batch_i = [] batch_j = [] batch_x_u = [] batch_x_d = [] + batch_x_s = [] batch_y_u = [] batch_y_d = [] + batch_y_s = [] batch = 0 img_patch[:] = 0 + # decode + seg = np.argmax(prediction, axis=2) + conf = prediction[tuple(np.indices(seg.shape)) + (seg,)] if thresholding_for_artificial_class: - seg_mask_label(prediction, mask_artificial_class, + seg_art = prediction[:, :, artificial_class] >= threshold_art_class + seg_mask_label(seg, seg_art, label=artificial_class, only=True, skeletonize=True, dilate=3, keep=separator_class) gc.collect() - return prediction, confidence + return seg, conf # variant of do_prediction_new_concept with no need # for resizing or tiling into patches - done on model