do_prediction*(): smooth window transitions with sigmoid…

Instead of hard cut-offs between overlapping window tiles,
apply sigmoid attenuation to blend smoothly from one tile to the next.

(apply all postprocessing at the end)
This commit is contained in:
Robert Sachunsky 2026-05-08 05:18:00 +02:00
parent cefe596f8b
commit 68a26a5c3f

View file

@ -469,25 +469,23 @@ class Eynollah:
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model) margin = int(marginal_of_patch_percent * img_height_model)
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
width_mid = img_width_model - 2 * margin width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin height_mid = img_height_model - 2 * margin
img_h = img.shape[0] img_h = img.shape[0]
img_w = img.shape[1] img_w = img.shape[1]
if is_enhancement: prediction = None
prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8) nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
else: nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil(img_w / float(width_mid))
nyf = math.ceil(img_h / float(height_mid))
batch_i = [] batch_i = []
batch_j = [] batch_j = []
batch_x_u = [] batch_x_u = []
batch_x_d = [] batch_x_d = []
batch_x_s = []
batch_y_u = [] batch_y_u = []
batch_y_d = [] batch_y_d = []
batch_y_s = []
batch = 0 batch = 0
img_patch = np.zeros((n_batch_inference, img_patch = np.zeros((n_batch_inference,
@ -498,83 +496,101 @@ class Eynollah:
for j in range(nyf): for j in range(nyf):
index_x_d = i * width_mid index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w: if index_x_u > img_w:
index_x_s = index_x_u - img_w
index_x_u = img_w index_x_u = img_w
index_x_d = img_w - img_width_model index_x_d = img_w - img_width_model
else:
index_x_s = 0
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_y_u > img_h: if index_y_u > img_h:
index_y_s = index_y_u - img_h
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
else:
index_y_s = 0
batch_i.append(i) batch_i.append(i)
batch_j.append(j) batch_j.append(j)
batch_x_u.append(index_x_u) batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d) batch_x_d.append(index_x_d)
batch_x_s.append(index_x_s)
batch_y_d.append(index_y_d) batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u) batch_y_u.append(index_y_u)
batch_y_s.append(index_y_s)
img_patch[batch] = img[index_y_d: index_y_u, img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u] index_x_d: index_x_u]
batch += 1 batch += 1
if (batch == n_batch_inference or if (batch == n_batch_inference or
# last batch # last batch
i == nxf - 1 and j == nyf - 1): i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape)) self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch, verbose=0) label_p_pred = model.predict(img_patch, verbose=0)
if is_enhancement: if prediction is None:
seg = (label_p_pred * 255).astype(np.uint8) # now we know the number of classes
else: prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
seg = np.argmax(label_p_pred, axis=3)
if thresholding_for_some_classes:
seg_mask_label(
seg, label_p_pred[:,:,:,4] > 0.03,
label=4) #
seg_mask_label(
seg, label_p_pred[:,:,:,0] > 0.25,
label=0) # bg
seg_mask_label(
seg, label_p_pred[:,:,:,3] > 0.10 & seg == 0,
label=3) # line
if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
for batch in range(range): for batch in range(range):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch], where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]] batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None, # shorter window on last tile
(-margin or None) if batch_j[batch] < nyf - 1 else None) part = np.index_exp[batch_y_s[batch]:,
inx = slice(margin if batch_i[batch] else None, batch_x_s[batch]:]
(-margin or None) if batch_i[batch] < nxf - 1 else None) # normalize probability (where windows overlap)
inbox = np.index_exp[iny, inx] attenuation_y = np.ones(img_height_model - batch_y_s[batch])
prediction[where][inbox] = seg[batch][inbox] attenuation_x = np.ones(img_width_model - batch_x_s[batch])
if thresholding_for_artificial_class: if margin and batch_j[batch] > 0:
mask_artificial_class[where][inbox] = seg_art[batch][inbox] attenuation_y[:2 * margin] = window
if margin and batch_j[batch] < nyf - 1:
attenuation_y[-2 * margin:] = 1 - window
if margin and batch_i[batch] > 0:
attenuation_x[:2 * margin] = window
if margin and batch_i[batch] < nxf - 1:
attenuation_x[-2 * margin:] = 1 - window
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
prediction[where][part] += label_p_pred[batch][part]
batch_i = [] batch_i = []
batch_j = [] batch_j = []
batch_x_u = [] batch_x_u = []
batch_x_d = [] batch_x_d = []
batch_x_s = []
batch_y_u = [] batch_y_u = []
batch_y_d = [] batch_y_d = []
batch_y_s = []
batch = 0 batch = 0
img_patch[:] = 0 img_patch[:] = 0
if is_enhancement:
seg = (prediction * 255).astype(np.uint8)
else:
seg = np.argmax(prediction, axis=2)
if thresholding_for_some_classes:
seg_mask_label(
seg, prediction[:, :, 4] > 0.03,
label=4) #
seg_mask_label(
seg, prediction[:, :, 0] > 0.25,
label=0) # bg
seg_mask_label(
seg, prediction[:, :, 3] > 0.10 & seg == 0,
label=3) # line
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_mask_label(prediction, mask_artificial_class, seg_art = prediction[:, :, artificial_class] >= threshold_art_class
seg_mask_label(seg, seg_art,
label=artificial_class, label=artificial_class,
only=True, only=True,
skeletonize=True, skeletonize=True,
dilate=3) dilate=3)
if img_h != img_h_page or img_w != img_w_page: if img_h != img_h_page or img_w != img_w_page:
prediction = resize_image(prediction, img_h_page, img_w_page) seg = resize_image(seg, img_h_page, img_w_page)
gc.collect() gc.collect()
return prediction return seg
def do_prediction_new_concept( def do_prediction_new_concept(
self, patches, img, model, self, patches, img, model,
@ -630,14 +646,12 @@ class Eynollah:
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model) self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
margin = int(marginal_of_patch_percent * img_height_model) margin = int(marginal_of_patch_percent * img_height_model)
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
width_mid = img_width_model - 2 * margin width_mid = img_width_model - 2 * margin
height_mid = img_height_model - 2 * margin height_mid = img_height_model - 2 * margin
img_h = img.shape[0] img_h = img.shape[0]
img_w = img.shape[1] img_w = img.shape[1]
prediction = np.zeros((img_h, img_w), dtype=np.uint8) prediction = None
confidence = np.zeros((img_h, img_w))
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil((img_w - 2.0 * margin) / width_mid) nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
nyf = math.ceil((img_h - 2.0 * margin) / height_mid) nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
@ -645,8 +659,10 @@ class Eynollah:
batch_j = [] batch_j = []
batch_x_u = [] batch_x_u = []
batch_x_d = [] batch_x_d = []
batch_x_s = []
batch_y_u = [] batch_y_u = []
batch_y_d = [] batch_y_d = []
batch_y_s = []
batch = 0 batch = 0
img_patch = np.zeros((n_batch_inference, img_patch = np.zeros((n_batch_inference,
img_height_model, img_height_model,
@ -657,20 +673,28 @@ class Eynollah:
index_x_d = i * width_mid index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model index_x_u = index_x_d + img_width_model
if index_x_u > img_w: if index_x_u > img_w:
index_x_s = index_x_u - img_w
index_x_u = img_w index_x_u = img_w
index_x_d = img_w - img_width_model index_x_d = img_w - img_width_model
else:
index_x_s = 0
index_y_d = j * height_mid index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model index_y_u = index_y_d + img_height_model
if index_y_u > img_h: if index_y_u > img_h:
index_y_s = index_y_u - img_h
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
else:
index_y_s = 0
batch_i.append(i) batch_i.append(i)
batch_j.append(j) batch_j.append(j)
batch_x_u.append(index_x_u) batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d) batch_x_d.append(index_x_d)
batch_x_s.append(index_x_s)
batch_y_d.append(index_y_d) batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u) batch_y_u.append(index_y_u)
batch_y_s.append(index_y_s)
img_patch[batch] = img[index_y_d: index_y_u, img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u] index_x_d: index_x_u]
@ -679,43 +703,56 @@ class Eynollah:
# last batch # last batch
i == nxf - 1 and j == nyf - 1): i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape)) self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch,verbose=0) label_p_pred = model.predict(img_patch, verbose=0)
seg = np.argmax(label_p_pred, axis=3) if prediction is None:
conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] # now we know the number of classes
if thresholding_for_artificial_class: prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
for batch in range(batch): for batch in range(batch):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch], where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]] batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None, # shorter window on last tile
(-margin or None) if batch_j[batch] < nyf - 1 else None) part = np.index_exp[batch_y_s[batch]:,
inx = slice(margin if batch_i[batch] else None, batch_x_s[batch]:]
(-margin or None) if batch_i[batch] < nxf - 1 else None) # normalize probability (where windows overlap)
inbox = np.index_exp[iny, inx] attenuation_y = np.ones(img_height_model - batch_y_s[batch])
prediction[where][inbox] = seg[batch][inbox] attenuation_x = np.ones(img_width_model - batch_x_s[batch])
confidence[where][inbox] = conf[batch][inbox] if margin and batch_j[batch] > 0:
if thresholding_for_artificial_class: attenuation_y[:2 * margin] = window
mask_artificial_class[where][inbox] = seg_art[batch][inbox] if margin and batch_j[batch] < nyf - 1:
attenuation_y[-2 * margin:] = 1 - window
if margin and batch_i[batch] > 0:
attenuation_x[:2 * margin] = window
if margin and batch_i[batch] < nxf - 1:
attenuation_x[-2 * margin:] = 1 - window
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
prediction[where][part] += label_p_pred[batch][part]
batch_i = [] batch_i = []
batch_j = [] batch_j = []
batch_x_u = [] batch_x_u = []
batch_x_d = [] batch_x_d = []
batch_x_s = []
batch_y_u = [] batch_y_u = []
batch_y_d = [] batch_y_d = []
batch_y_s = []
batch = 0 batch = 0
img_patch[:] = 0 img_patch[:] = 0
# decode
seg = np.argmax(prediction, axis=2)
conf = prediction[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_mask_label(prediction, mask_artificial_class, seg_art = prediction[:, :, artificial_class] >= threshold_art_class
seg_mask_label(seg, seg_art,
label=artificial_class, label=artificial_class,
only=True, only=True,
skeletonize=True, skeletonize=True,
dilate=3, dilate=3,
keep=separator_class) keep=separator_class)
gc.collect() gc.collect()
return prediction, confidence return seg, conf
# variant of do_prediction_new_concept with no need # variant of do_prediction_new_concept with no need
# for resizing or tiling into patches - done on model # for resizing or tiling into patches - done on model