diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 180a12a..ff9e8e6 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -657,8 +657,9 @@ class Eynollah: dilate=3, keep=separator_class) - conf_text = resize_image(label_p_pred[:, :, 1], img_h_page, img_w_page) - return prediction, conf_text + conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] + conf = resize_image(conf, img_h_page, img_w_page) + return prediction, conf if img.shape[0] < img_height_model: img = resize_image(img, img_height_model, img.shape[1]) @@ -717,7 +718,7 @@ class Eynollah: self.logger.debug("predicting patches on %s", str(img_patch.shape)) label_p_pred = model.predict(img_patch,verbose=0) seg = np.argmax(label_p_pred, axis=3) - conf = label_p_pred[:, :, :, 1] + conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] if thresholding_for_artificial_class: seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class