do_prediction*(): smooth window transitions with a sigmoid ramp

Instead of hard cut-offs between overlapping window tiles, apply sigmoid
attenuation to blend smoothly from one tile to the next.

(apply all postprocessing at the end)
This commit is contained in:
Robert Sachunsky 2026-05-08 05:18:00 +02:00
parent cefe596f8b
commit 68a26a5c3f

View file

@ -469,25 +469,23 @@ class Eynollah:
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model)
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img_h = img.shape[0]
img_w = img.shape[1]
if is_enhancement:
prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8)
else:
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil(img_w / float(width_mid))
nyf = math.ceil(img_h / float(height_mid))
prediction = None
nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_x_s = []
batch_y_u = []
batch_y_d = []
batch_y_s = []
batch = 0
img_patch = np.zeros((n_batch_inference,
@ -498,83 +496,101 @@ class Eynollah:
for j in range(nyf):
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_s = index_x_u - img_w
index_x_u = img_w
index_x_d = img_w - img_width_model
else:
index_x_s = 0
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_y_u > img_h:
index_y_s = index_y_u - img_h
index_y_u = img_h
index_y_d = img_h - img_height_model
else:
index_y_s = 0
batch_i.append(i)
batch_j.append(j)
batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d)
batch_x_s.append(index_x_s)
batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u)
batch_y_s.append(index_y_s)
img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u]
batch += 1
if (batch == n_batch_inference or
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch, verbose=0)
if is_enhancement:
seg = (label_p_pred * 255).astype(np.uint8)
else:
seg = np.argmax(label_p_pred, axis=3)
if thresholding_for_some_classes:
seg_mask_label(
seg, label_p_pred[:,:,:,4] > 0.03,
label=4) #
seg_mask_label(
seg, label_p_pred[:,:,:,0] > 0.25,
label=0) # bg
seg_mask_label(
seg, label_p_pred[:,:,:,3] > 0.10 & seg == 0,
label=3) # line
if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
if prediction is None:
# now we know the number of classes
prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
for batch in range(range):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
# shorter window on last tile
part = np.index_exp[batch_y_s[batch]:,
batch_x_s[batch]:]
# normalize probability (where windows overlap)
attenuation_y = np.ones(img_height_model - batch_y_s[batch])
attenuation_x = np.ones(img_width_model - batch_x_s[batch])
if margin and batch_j[batch] > 0:
attenuation_y[:2 * margin] = window
if margin and batch_j[batch] < nyf - 1:
attenuation_y[-2 * margin:] = 1 - window
if margin and batch_i[batch] > 0:
attenuation_x[:2 * margin] = window
if margin and batch_i[batch] < nxf - 1:
attenuation_x[-2 * margin:] = 1 - window
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
prediction[where][part] += label_p_pred[batch][part]
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_x_s = []
batch_y_u = []
batch_y_d = []
batch_y_s = []
batch = 0
img_patch[:] = 0
if is_enhancement:
seg = (prediction * 255).astype(np.uint8)
else:
seg = np.argmax(prediction, axis=2)
if thresholding_for_some_classes:
seg_mask_label(
seg, prediction[:, :, 4] > 0.03,
label=4) #
seg_mask_label(
seg, prediction[:, :, 0] > 0.25,
label=0) # bg
seg_mask_label(
seg, prediction[:, :, 3] > 0.10 & seg == 0,
label=3) # line
if thresholding_for_artificial_class:
seg_mask_label(prediction, mask_artificial_class,
seg_art = prediction[:, :, artificial_class] >= threshold_art_class
seg_mask_label(seg, seg_art,
label=artificial_class,
only=True,
skeletonize=True,
dilate=3)
if img_h != img_h_page or img_w != img_w_page:
prediction = resize_image(prediction, img_h_page, img_w_page)
seg = resize_image(seg, img_h_page, img_w_page)
gc.collect()
return prediction
return seg
def do_prediction_new_concept(
self, patches, img, model,
@ -630,14 +646,12 @@ class Eynollah:
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model)
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin
img_h = img.shape[0]
img_w = img.shape[1]
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
confidence = np.zeros((img_h, img_w))
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
prediction = None
nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
@ -645,8 +659,10 @@ class Eynollah:
batch_j = []
batch_x_u = []
batch_x_d = []
batch_x_s = []
batch_y_u = []
batch_y_d = []
batch_y_s = []
batch = 0
img_patch = np.zeros((n_batch_inference,
img_height_model,
@ -657,20 +673,28 @@ class Eynollah:
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
if index_x_u > img_w:
index_x_s = index_x_u - img_w
index_x_u = img_w
index_x_d = img_w - img_width_model
else:
index_x_s = 0
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_y_u > img_h:
index_y_s = index_y_u - img_h
index_y_u = img_h
index_y_d = img_h - img_height_model
else:
index_y_s = 0
batch_i.append(i)
batch_j.append(j)
batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d)
batch_x_s.append(index_x_s)
batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u)
batch_y_s.append(index_y_s)
img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u]
@ -679,43 +703,56 @@ class Eynollah:
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch,verbose=0)
seg = np.argmax(label_p_pred, axis=3)
conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
label_p_pred = model.predict(img_patch, verbose=0)
if prediction is None:
# now we know the number of classes
prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
for batch in range(batch):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
confidence[where][inbox] = conf[batch][inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
# shorter window on last tile
part = np.index_exp[batch_y_s[batch]:,
batch_x_s[batch]:]
# normalize probability (where windows overlap)
attenuation_y = np.ones(img_height_model - batch_y_s[batch])
attenuation_x = np.ones(img_width_model - batch_x_s[batch])
if margin and batch_j[batch] > 0:
attenuation_y[:2 * margin] = window
if margin and batch_j[batch] < nyf - 1:
attenuation_y[-2 * margin:] = 1 - window
if margin and batch_i[batch] > 0:
attenuation_x[:2 * margin] = window
if margin and batch_i[batch] < nxf - 1:
attenuation_x[-2 * margin:] = 1 - window
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
prediction[where][part] += label_p_pred[batch][part]
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_x_s = []
batch_y_u = []
batch_y_d = []
batch_y_s = []
batch = 0
img_patch[:] = 0
# decode
seg = np.argmax(prediction, axis=2)
conf = prediction[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class:
seg_mask_label(prediction, mask_artificial_class,
seg_art = prediction[:, :, artificial_class] >= threshold_art_class
seg_mask_label(seg, seg_art,
label=artificial_class,
only=True,
skeletonize=True,
dilate=3,
keep=separator_class)
gc.collect()
return prediction, confidence
return seg, conf
# variant of do_prediction_new_concept with no need
# for resizing or tiling into patches - done on model