From 0f82b568bac06c48247e69a31f7c9486c7a2290a Mon Sep 17 00:00:00 2001 From: Robert Sachunsky Date: Thu, 16 Apr 2026 05:02:20 +0200 Subject: [PATCH] =?UTF-8?q?do=5Fprediction=5Fnew=5Fconcept:=20aggregate=20?= =?UTF-8?q?confidence=20for=20all=20classes=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (not just text; will still have to pass that on to the writer...) --- src/eynollah/eynollah.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/eynollah/eynollah.py b/src/eynollah/eynollah.py index 180a12a..ff9e8e6 100644 --- a/src/eynollah/eynollah.py +++ b/src/eynollah/eynollah.py @@ -657,8 +657,9 @@ class Eynollah: dilate=3, keep=separator_class) - conf_text = resize_image(label_p_pred[:, :, 1], img_h_page, img_w_page) - return prediction, conf_text + conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] + conf = resize_image(conf, img_h_page, img_w_page) + return prediction, conf if img.shape[0] < img_height_model: img = resize_image(img, img_height_model, img.shape[1]) @@ -717,7 +718,7 @@ class Eynollah: self.logger.debug("predicting patches on %s", str(img_patch.shape)) label_p_pred = model.predict(img_patch,verbose=0) seg = np.argmax(label_p_pred, axis=3) - conf = label_p_pred[:, :, :, 1] + conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] if thresholding_for_artificial_class: seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class