do_prediction*(): avoid unnecessary tiles, simplify…

- calculation for number of tiles: sometimes one less
  tile is needed by making the previous last tile
  half-full on the right side
- calculation of window margins: fix case if dimension
  extends to full image shape
- simplify (identifiers, slicing etc)
This commit is contained in:
Robert Sachunsky 2026-05-08 00:55:18 +02:00
parent d8c83d6137
commit cefe596f8b

View file

@@ -482,15 +482,18 @@ class Eynollah:
nxf = math.ceil(img_w / float(width_mid)) nxf = math.ceil(img_w / float(width_mid))
nyf = math.ceil(img_h / float(height_mid)) nyf = math.ceil(img_h / float(height_mid))
list_i_s = [] batch_i = []
list_j_s = [] batch_j = []
list_x_u = [] batch_x_u = []
list_x_d = [] batch_x_d = []
list_y_u = [] batch_y_u = []
list_y_d = [] batch_y_d = []
batch_indexer = 0 batch = 0
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3), dtype=np.float16) img_patch = np.zeros((n_batch_inference,
img_height_model,
img_width_model,
3), dtype=np.float16)
for i in range(nxf): for i in range(nxf):
for j in range(nyf): for j in range(nyf):
index_x_d = i * width_mid index_x_d = i * width_mid
@@ -504,18 +507,18 @@ class Eynollah:
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
list_i_s.append(i) batch_i.append(i)
list_j_s.append(j) batch_j.append(j)
list_x_u.append(index_x_u) batch_x_u.append(index_x_u)
list_x_d.append(index_x_d) batch_x_d.append(index_x_d)
list_y_d.append(index_y_d) batch_y_d.append(index_y_d)
list_y_u.append(index_y_u) batch_y_u.append(index_y_u)
img_patch[batch_indexer] = img[index_y_d:index_y_u, img_patch[batch] = img[index_y_d: index_y_u,
index_x_d:index_x_u] index_x_d: index_x_u]
batch_indexer += 1 batch += 1
if (batch_indexer == n_batch_inference or if (batch == n_batch_inference or
# last batch # last batch
i == nxf - 1 and j == nyf - 1): i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape)) self.logger.debug("predicting patches on %s", str(img_patch.shape))
@@ -539,75 +542,25 @@ class Eynollah:
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
indexer_inside_batch = 0 for batch in range(batch):
for i_batch, j_batch in zip(list_i_s, list_j_s): where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
seg_in = seg[indexer_inside_batch] batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_in_art = seg_art[indexer_inside_batch] mask_artificial_class[where][inbox] = seg_art[batch][inbox]
index_y_u_in = list_y_u[indexer_inside_batch] batch_i = []
index_y_d_in = list_y_d[indexer_inside_batch] batch_j = []
batch_x_u = []
index_x_u_in = list_x_u[indexer_inside_batch] batch_x_d = []
index_x_d_in = list_x_d[indexer_inside_batch] batch_y_u = []
batch_y_d = []
where = np.index_exp[index_y_d_in:index_y_u_in, batch = 0
index_x_d_in:index_x_u_in]
if (i_batch == 0 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:]
elif (i_batch == 0 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:]
elif (i_batch == 0 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
margin:]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:-margin or None]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:-margin or None]
else:
inbox = np.index_exp[margin:-margin or None,
margin:-margin or None]
prediction[where][inbox] = seg_in[inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_in_art[inbox]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch[:] = 0 img_patch[:] = 0
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
@@ -685,123 +638,73 @@ class Eynollah:
confidence = np.zeros((img_h, img_w)) confidence = np.zeros((img_h, img_w))
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool) mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil(img_w / float(width_mid)) nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
nyf = math.ceil(img_h / float(height_mid)) nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
list_i_s = [] batch_i = []
list_j_s = [] batch_j = []
list_x_u = [] batch_x_u = []
list_x_d = [] batch_x_d = []
list_y_u = [] batch_y_u = []
list_y_d = [] batch_y_d = []
batch = 0
batch_indexer = 0 img_patch = np.zeros((n_batch_inference,
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3), dtype=np.float16) img_height_model,
img_width_model,
3), dtype=np.float16)
for i in range(nxf): for i in range(nxf):
for j in range(nyf): for j in range(nyf):
index_x_d = i * width_mid index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w: if index_x_u > img_w:
index_x_u = img_w index_x_u = img_w
index_x_d = img_w - img_width_model index_x_d = img_w - img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_y_u > img_h: if index_y_u > img_h:
index_y_u = img_h index_y_u = img_h
index_y_d = img_h - img_height_model index_y_d = img_h - img_height_model
list_i_s.append(i) batch_i.append(i)
list_j_s.append(j) batch_j.append(j)
list_x_u.append(index_x_u) batch_x_u.append(index_x_u)
list_x_d.append(index_x_d) batch_x_d.append(index_x_d)
list_y_d.append(index_y_d) batch_y_d.append(index_y_d)
list_y_u.append(index_y_u) batch_y_u.append(index_y_u)
img_patch[batch_indexer] = img[index_y_d:index_y_u, img_patch[batch] = img[index_y_d: index_y_u,
index_x_d:index_x_u] index_x_d: index_x_u]
batch_indexer += 1 batch += 1
if (batch == n_batch_inference or
if (batch_indexer == n_batch_inference or
# last batch # last batch
i == nxf - 1 and j == nyf - 1): i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape)) self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch,verbose=0) label_p_pred = model.predict(img_patch,verbose=0)
seg = np.argmax(label_p_pred, axis=3) seg = np.argmax(label_p_pred, axis=3)
conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)] conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
indexer_inside_batch = 0 for batch in range(batch):
for i_batch, j_batch in zip(list_i_s, list_j_s): where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
seg_in = seg[indexer_inside_batch] batch_x_d[batch]: batch_x_u[batch]]
conf_in = conf[indexer_inside_batch] iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
confidence[where][inbox] = conf[batch][inbox]
if thresholding_for_artificial_class: if thresholding_for_artificial_class:
seg_in_art = seg_art[indexer_inside_batch] mask_artificial_class[where][inbox] = seg_art[batch][inbox]
index_y_u_in = list_y_u[indexer_inside_batch] batch_i = []
index_y_d_in = list_y_d[indexer_inside_batch] batch_j = []
batch_x_u = []
index_x_u_in = list_x_u[indexer_inside_batch] batch_x_d = []
index_x_d_in = list_x_d[indexer_inside_batch] batch_y_u = []
batch_y_d = []
where = np.index_exp[index_y_d_in:index_y_u_in, batch = 0
index_x_d_in:index_x_u_in]
if (i_batch == 0 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:]
elif (i_batch == 0 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:]
elif (i_batch == 0 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
margin:]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:-margin or None]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:-margin or None]
else:
inbox = np.index_exp[margin:-margin or None,
margin:-margin or None]
prediction[where][inbox] = seg_in[inbox]
confidence[where][inbox] = conf_in[inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_in_art[inbox]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch[:] = 0 img_patch[:] = 0
if thresholding_for_artificial_class: if thresholding_for_artificial_class: