do_prediction*(): avoid unnecessary tiles, simplify…

- calculation for number of tiles: sometimes one less
  tile is needed by making the previous last tile
  half-full on the right side
- calculation of window margins: fix case if dimension
  extends to full image shape
- simplify (identifiers, slicing etc)
This commit is contained in:
Robert Sachunsky 2026-05-08 00:55:18 +02:00
parent d8c83d6137
commit cefe596f8b

View file

@ -482,15 +482,18 @@ class Eynollah:
nxf = math.ceil(img_w / float(width_mid))
nyf = math.ceil(img_h / float(height_mid))
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_y_u = []
batch_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3), dtype=np.float16)
batch = 0
img_patch = np.zeros((n_batch_inference,
img_height_model,
img_width_model,
3), dtype=np.float16)
for i in range(nxf):
for j in range(nyf):
index_x_d = i * width_mid
@ -504,18 +507,18 @@ class Eynollah:
index_y_u = img_h
index_y_d = img_h - img_height_model
list_i_s.append(i)
list_j_s.append(j)
list_x_u.append(index_x_u)
list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
list_y_u.append(index_y_u)
batch_i.append(i)
batch_j.append(j)
batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d)
batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u)
img_patch[batch_indexer] = img[index_y_d:index_y_u,
index_x_d:index_x_u]
batch_indexer += 1
img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u]
batch += 1
if (batch_indexer == n_batch_inference or
if (batch == n_batch_inference or
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
@ -539,75 +542,25 @@ class Eynollah:
if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
indexer_inside_batch = 0
for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch]
for batch in range(range):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
if thresholding_for_artificial_class:
seg_in_art = seg_art[indexer_inside_batch]
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
index_y_u_in = list_y_u[indexer_inside_batch]
index_y_d_in = list_y_d[indexer_inside_batch]
index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]
where = np.index_exp[index_y_d_in:index_y_u_in,
index_x_d_in:index_x_u_in]
if (i_batch == 0 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:]
elif (i_batch == 0 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:]
elif (i_batch == 0 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
margin:]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:-margin or None]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:-margin or None]
else:
inbox = np.index_exp[margin:-margin or None,
margin:-margin or None]
prediction[where][inbox] = seg_in[inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_in_art[inbox]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_y_u = []
batch_y_d = []
batch = 0
img_patch[:] = 0
if thresholding_for_artificial_class:
@ -685,123 +638,73 @@ class Eynollah:
confidence = np.zeros((img_h, img_w))
if thresholding_for_artificial_class:
mask_artificial_class = np.zeros((img_h, img_w), dtype=bool)
nxf = math.ceil(img_w / float(width_mid))
nyf = math.ceil(img_h / float(height_mid))
nxf = math.ceil((img_w - 2.0 * margin) / width_mid)
nyf = math.ceil((img_h - 2.0 * margin) / height_mid)
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3), dtype=np.float16)
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_y_u = []
batch_y_d = []
batch = 0
img_patch = np.zeros((n_batch_inference,
img_height_model,
img_width_model,
3), dtype=np.float16)
for i in range(nxf):
for j in range(nyf):
index_x_d = i * width_mid
index_x_u = index_x_d + img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - img_width_model
index_y_d = j * height_mid
index_y_u = index_y_d + img_height_model
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model
list_i_s.append(i)
list_j_s.append(j)
list_x_u.append(index_x_u)
list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
list_y_u.append(index_y_u)
batch_i.append(i)
batch_j.append(j)
batch_x_u.append(index_x_u)
batch_x_d.append(index_x_d)
batch_y_d.append(index_y_d)
batch_y_u.append(index_y_u)
img_patch[batch_indexer] = img[index_y_d:index_y_u,
index_x_d:index_x_u]
batch_indexer += 1
if (batch_indexer == n_batch_inference or
img_patch[batch] = img[index_y_d: index_y_u,
index_x_d: index_x_u]
batch += 1
if (batch == n_batch_inference or
# last batch
i == nxf - 1 and j == nyf - 1):
self.logger.debug("predicting patches on %s", str(img_patch.shape))
label_p_pred = model.predict(img_patch,verbose=0)
seg = np.argmax(label_p_pred, axis=3)
conf = label_p_pred[tuple(np.indices(seg.shape)) + (seg,)]
if thresholding_for_artificial_class:
seg_art = label_p_pred[:, :, :, artificial_class] >= threshold_art_class
indexer_inside_batch = 0
for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch]
conf_in = conf[indexer_inside_batch]
for batch in range(batch):
where = np.index_exp[batch_y_d[batch]: batch_y_u[batch],
batch_x_d[batch]: batch_x_u[batch]]
iny = slice(margin if batch_j[batch] else None,
(-margin or None) if batch_j[batch] < nyf - 1 else None)
inx = slice(margin if batch_i[batch] else None,
(-margin or None) if batch_i[batch] < nxf - 1 else None)
inbox = np.index_exp[iny, inx]
prediction[where][inbox] = seg[batch][inbox]
confidence[where][inbox] = conf[batch][inbox]
if thresholding_for_artificial_class:
seg_in_art = seg_art[indexer_inside_batch]
mask_artificial_class[where][inbox] = seg_art[batch][inbox]
index_y_u_in = list_y_u[indexer_inside_batch]
index_y_d_in = list_y_d[indexer_inside_batch]
index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]
where = np.index_exp[index_y_d_in:index_y_u_in,
index_x_d_in:index_x_u_in]
if (i_batch == 0 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:]
elif (i_batch == 0 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:]
elif (i_batch == 0 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
0:-margin or None]
elif (i_batch == nxf - 1 and
j_batch != 0 and
j_batch != nyf - 1):
inbox = np.index_exp[margin:-margin or None,
margin:]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == 0):
inbox = np.index_exp[0:-margin or None,
margin:-margin or None]
elif (i_batch != 0 and
i_batch != nxf - 1 and
j_batch == nyf - 1):
inbox = np.index_exp[margin:,
margin:-margin or None]
else:
inbox = np.index_exp[margin:-margin or None,
margin:-margin or None]
prediction[where][inbox] = seg_in[inbox]
confidence[where][inbox] = conf_in[inbox]
if thresholding_for_artificial_class:
mask_artificial_class[where][inbox] = seg_in_art[inbox]
indexer_inside_batch += 1
list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []
batch_indexer = 0
batch_i = []
batch_j = []
batch_x_u = []
batch_x_d = []
batch_y_u = []
batch_y_d = []
batch = 0
img_patch[:] = 0
if thresholding_for_artificial_class: