mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-05-13 01:13:54 +02:00
do_prediction*(): smooth window transitions with sigmoid…
Instead of hard cut-offs between overlapping window tiles, apply sigmoid attenuation to slide smoothly from one tile to the next (and apply all postprocessing at the end).
This commit is contained in:
parent
cefe596f8b
commit
68a26a5c3f
1 changed file with 97 additions and 60 deletions
|
|
@ -469,25 +469,23 @@ class Eynollah:
|
|||
|
||||
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
|
||||
margin = int(marginal_of_patch_percent * img_height_model)
|
||||
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
|
||||
width_mid = img_width_model - 2 * margin
|
||||
height_mid = img_height_model - 2 * margin
|
||||
img_h = img.shape[0]
|
||||
img_w = img.shape[1]
|
||||
if is_enhancement:
|
||||
prediction = np.zeros((img_h, img_w, 3), dtype=np.uint8)
|
||||
else:
|
||||
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
|
||||
if thresholding_for_artificial_class:
|
||||
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
|
||||
nxf = math.ceil(img_w / float(width_mid))
|
||||
nyf = math.ceil(img_h / float(height_mid))
|
||||
prediction = None
|
||||
nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
|
||||
nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
|
||||
|
||||
batch_i = []
|
||||
batch_j = []
|
||||
batch_x_u = []
|
||||
batch_x_d = []
|
||||
batch_x_s = []
|
||||
batch_y_u = []
|
||||
batch_y_d = []
|
||||
batch_y_s = []
|
||||
|
||||
batch = 0
|
||||
img_patch = np.zeros((n_batch_inference,
|
||||
|
|
@ -498,83 +496,101 @@ class Eynollah:
|
|||
for j in range(nyf):
|
||||
index_x_d = i * width_mid
|
||||
index_x_u = index_x_d + img_width_model
|
||||
index_y_d = j * height_mid
|
||||
index_y_u = index_y_d + img_height_model
|
||||
if index_x_u > img_w:
|
||||
index_x_s = index_x_u - img_w
|
||||
index_x_u = img_w
|
||||
index_x_d = img_w - img_width_model
|
||||
else:
|
||||
index_x_s = 0
|
||||
index_y_d = j * height_mid
|
||||
index_y_u = index_y_d + img_height_model
|
||||
if index_y_u > img_h:
|
||||
index_y_s = index_y_u - img_h
|
||||
index_y_u = img_h
|
||||
index_y_d = img_h - img_height_model
|
||||
else:
|
||||
index_y_s = 0
|
||||
|
||||
batch_i.append(i)
|
||||
batch_j.append(j)
|
||||
batch_x_u.append(index_x_u)
|
||||
batch_x_d.append(index_x_d)
|
||||
batch_x_s.append(index_x_s)
|
||||
batch_y_d.append(index_y_d)
|
||||
batch_y_u.append(index_y_u)
|
||||
batch_y_s.append(index_y_s)
|
||||
|
||||
img_patch[batch] = img[index_y_d: index_y_u,
|
||||
index_x_d: index_x_u]
|
||||
batch += 1
|
||||
|
||||
if (batch == n_batch_inference or
|
||||
# last batch
|
||||
i == nxf - 1 and j == nyf - 1):
|
||||
self.logger.debug("predicting patches on %s", str(img_patch.shape))
|
||||
label_p_pred = model.predict(img_patch, verbose=0)
|
||||
if is_enhancement:
|
||||
seg = (label_p_pred * 255).astype(np.uint8)
|
||||
else:
|
||||
seg = np.argmax(label_p_pred, axis=3)
|
||||
|
||||
if thresholding_for_some_classes:
|
||||
seg_mask_label(
|
||||
seg, label_p_pred[:,:,:,4] > 0.03,
|
||||
label=4) #
|
||||
seg_mask_label(
|
||||
seg, label_p_pred[:,:,:,0] > 0.25,
|
||||
label=0) # bg
|
||||
seg_mask_label(
|
||||
seg, label_p_pred[:,:,:,3] > 0.10 & seg == 0,
|
||||
label=3) # line
|
||||
|
||||
if thresholding_for_artificial_class:
|
||||
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
|
||||
if prediction is None:
|
||||
# now we know the number of classes
|
||||
prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
|
||||
|
||||
for batch in range(range):
|
||||
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
|
||||
batch_x_d[batch]: batch_x_u[batch]]
|
||||
iny = slice(margin if batch_j[batch] else None,
|
||||
(-margin or None) if batch_j[batch] < nyf - 1 else None)
|
||||
inx = slice(margin if batch_i[batch] else None,
|
||||
(-margin or None) if batch_i[batch] < nxf - 1 else None)
|
||||
inbox = np.index_exp[iny, inx]
|
||||
prediction[where][inbox] = seg[batch][inbox]
|
||||
if thresholding_for_artificial_class:
|
||||
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
|
||||
# shorter window on last tile
|
||||
part = np.index_exp[batch_y_s[batch]:,
|
||||
batch_x_s[batch]:]
|
||||
# normalize probability (where windows overlap)
|
||||
attenuation_y = np.ones(img_height_model - batch_y_s[batch])
|
||||
attenuation_x = np.ones(img_width_model - batch_x_s[batch])
|
||||
if margin and batch_j[batch] > 0:
|
||||
attenuation_y[:2 * margin] = window
|
||||
if margin and batch_j[batch] < nyf - 1:
|
||||
attenuation_y[-2 * margin:] = 1 - window
|
||||
if margin and batch_i[batch] > 0:
|
||||
attenuation_x[:2 * margin] = window
|
||||
if margin and batch_i[batch] < nxf - 1:
|
||||
attenuation_x[-2 * margin:] = 1 - window
|
||||
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
|
||||
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
|
||||
prediction[where][part] += label_p_pred[batch][part]
|
||||
|
||||
batch_i = []
|
||||
batch_j = []
|
||||
batch_x_u = []
|
||||
batch_x_d = []
|
||||
batch_x_s = []
|
||||
batch_y_u = []
|
||||
batch_y_d = []
|
||||
batch_y_s = []
|
||||
batch = 0
|
||||
img_patch[:] = 0
|
||||
|
||||
if is_enhancement:
|
||||
seg = (prediction * 255).astype(np.uint8)
|
||||
else:
|
||||
seg = np.argmax(prediction, axis=2)
|
||||
if thresholding_for_some_classes:
|
||||
seg_mask_label(
|
||||
seg, prediction[:, :, 4] > 0.03,
|
||||
label=4) #
|
||||
seg_mask_label(
|
||||
seg, prediction[:, :, 0] > 0.25,
|
||||
label=0) # bg
|
||||
seg_mask_label(
|
||||
seg, prediction[:, :, 3] > 0.10 & seg == 0,
|
||||
label=3) # line
|
||||
if thresholding_for_artificial_class:
|
||||
seg_mask_label(prediction, mask_artificial_class,
|
||||
seg_art = prediction[:, :, artificial_class] >= threshold_art_class
|
||||
seg_mask_label(seg, seg_art,
|
||||
label=artificial_class,
|
||||
only=True,
|
||||
skeletonize=True,
|
||||
dilate=3)
|
||||
|
||||
if img_h != img_h_page or img_w != img_w_page:
|
||||
prediction = resize_image(prediction, img_h_page, img_w_page)
|
||||
seg = resize_image(seg, img_h_page, img_w_page)
|
||||
|
||||
gc.collect()
|
||||
return prediction
|
||||
return seg
|
||||
|
||||
def do_prediction_new_concept(
|
||||
self, patches, img, model,
|
||||
|
|
@ -630,14 +646,12 @@ class Eynollah:
|
|||
|
||||
self.logger.debug("Patch size: %sx%s", img_height_model, img_width_model)
|
||||
margin = int(marginal_of_patch_percent * img_height_model)
|
||||
window = 1 / (1 + np.exp(5.0 - 5 * np.arange(2 * margin) / margin))
|
||||
width_mid = img_width_model - 2 * margin
|
||||
height_mid = img_height_model - 2 * margin
|
||||
img_h = img.shape[0]
|
||||
img_w = img.shape[1]
|
||||
prediction = np.zeros((img_h, img_w), dtype=np.uint8)
|
||||
confidence = np.zeros((img_h, img_w))
|
||||
if thresholding_for_artificial_class:
|
||||
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
|
||||
prediction = None
|
||||
nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
|
||||
nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
|
||||
|
||||
|
|
@ -645,8 +659,10 @@ class Eynollah:
|
|||
batch_j = []
|
||||
batch_x_u = []
|
||||
batch_x_d = []
|
||||
batch_x_s = []
|
||||
batch_y_u = []
|
||||
batch_y_d = []
|
||||
batch_y_s = []
|
||||
batch = 0
|
||||
img_patch = np.zeros((n_batch_inference,
|
||||
img_height_model,
|
||||
|
|
@ -657,20 +673,28 @@ class Eynollah:
|
|||
index_x_d = i * width_mid
|
||||
index_x_u = index_x_d + img_width_model
|
||||
if index_x_u > img_w:
|
||||
index_x_s = index_x_u - img_w
|
||||
index_x_u = img_w
|
||||
index_x_d = img_w - img_width_model
|
||||
else:
|
||||
index_x_s = 0
|
||||
index_y_d = j * height_mid
|
||||
index_y_u = index_y_d + img_height_model
|
||||
if index_y_u > img_h:
|
||||
index_y_s = index_y_u - img_h
|
||||
index_y_u = img_h
|
||||
index_y_d = img_h - img_height_model
|
||||
else:
|
||||
index_y_s = 0
|
||||
|
||||
batch_i.append(i)
|
||||
batch_j.append(j)
|
||||
batch_x_u.append(index_x_u)
|
||||
batch_x_d.append(index_x_d)
|
||||
batch_x_s.append(index_x_s)
|
||||
batch_y_d.append(index_y_d)
|
||||
batch_y_u.append(index_y_u)
|
||||
batch_y_s.append(index_y_s)
|
||||
|
||||
img_patch[batch] = img[index_y_d: index_y_u,
|
||||
index_x_d: index_x_u]
|
||||
|
|
@ -679,43 +703,56 @@ class Eynollah:
|
|||
# last batch
|
||||
i == nxf - 1 and j == nyf - 1):
|
||||
self.logger.debug("predicting patches on %s", str(img_patch.shape))
|
||||
label_p_pred = model.predict(img_patch,verbose=0)
|
||||
seg = np.argmax(label_p_pred, axis=3)
|
||||
conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)]
|
||||
if thresholding_for_artificial_class:
|
||||
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
|
||||
label_p_pred = model.predict(img_patch, verbose=0)
|
||||
if prediction is None:
|
||||
# now we know the number of classes
|
||||
prediction = np.zeros((img_h, img_w, label_p_pred.shape[-1]), dtype=float)
|
||||
|
||||
for batch in range(batch):
|
||||
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
|
||||
batch_x_d[batch]: batch_x_u[batch]]
|
||||
iny = slice(margin if batch_j[batch] else None,
|
||||
(-margin or None) if batch_j[batch] < nyf - 1 else None)
|
||||
inx = slice(margin if batch_i[batch] else None,
|
||||
(-margin or None) if batch_i[batch] < nxf - 1 else None)
|
||||
inbox = np.index_exp[iny, inx]
|
||||
prediction[where][inbox] = seg[batch][inbox]
|
||||
confidence[where][inbox] = conf[batch][inbox]
|
||||
if thresholding_for_artificial_class:
|
||||
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
|
||||
# shorter window on last tile
|
||||
part = np.index_exp[batch_y_s[batch]:,
|
||||
batch_x_s[batch]:]
|
||||
# normalize probability (where windows overlap)
|
||||
attenuation_y = np.ones(img_height_model - batch_y_s[batch])
|
||||
attenuation_x = np.ones(img_width_model - batch_x_s[batch])
|
||||
if margin and batch_j[batch] > 0:
|
||||
attenuation_y[:2 * margin] = window
|
||||
if margin and batch_j[batch] < nyf - 1:
|
||||
attenuation_y[-2 * margin:] = 1 - window
|
||||
if margin and batch_i[batch] > 0:
|
||||
attenuation_x[:2 * margin] = window
|
||||
if margin and batch_i[batch] < nxf - 1:
|
||||
attenuation_x[-2 * margin:] = 1 - window
|
||||
label_p_pred[batch][part] *= attenuation_y[:, np.newaxis, np.newaxis]
|
||||
label_p_pred[batch][part] *= attenuation_x[np.newaxis, :, np.newaxis]
|
||||
prediction[where][part] += label_p_pred[batch][part]
|
||||
|
||||
batch_i = []
|
||||
batch_j = []
|
||||
batch_x_u = []
|
||||
batch_x_d = []
|
||||
batch_x_s = []
|
||||
batch_y_u = []
|
||||
batch_y_d = []
|
||||
batch_y_s = []
|
||||
batch = 0
|
||||
img_patch[:] = 0
|
||||
|
||||
# decode
|
||||
seg = np.argmax(prediction, axis=2)
|
||||
conf = prediction[tuple(np.indices(seg.shape)) + (seg,)]
|
||||
if thresholding_for_artificial_class:
|
||||
seg_mask_label(prediction, mask_artificial_class,
|
||||
seg_art = prediction[:, :, artificial_class] >= threshold_art_class
|
||||
seg_mask_label(seg, seg_art,
|
||||
label=artificial_class,
|
||||
only=True,
|
||||
skeletonize=True,
|
||||
dilate=3,
|
||||
keep=separator_class)
|
||||
gc.collect()
|
||||
return prediction, confidence
|
||||
return seg, conf
|
||||
|
||||
# variant of do_prediction_new_concept with no need
|
||||
# for resizing or tiling into patches - done on model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue