mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-02-20 16:32:03 +01:00
training.utils.data_gen: avoid repeated array allocation
This commit is contained in:
parent
514a897dd5
commit
83c2408192
1 changed files with 10 additions and 10 deletions
|
|
@ -600,10 +600,9 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c
|
||||||
c = 0
|
c = 0
|
||||||
n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images
|
n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images
|
||||||
random.shuffle(n)
|
random.shuffle(n)
|
||||||
|
img = np.zeros((batch_size, input_height, input_width, 3), dtype=float)
|
||||||
|
mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float)
|
||||||
while True:
|
while True:
|
||||||
img = np.zeros((batch_size, input_height, input_width, 3)).astype('float')
|
|
||||||
mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float')
|
|
||||||
|
|
||||||
for i in range(c, c + batch_size): # initially from 0 to 16, c = 0.
|
for i in range(c, c + batch_size): # initially from 0 to 16, c = 0.
|
||||||
try:
|
try:
|
||||||
filename = os.path.splitext(n[i])[0]
|
filename = os.path.splitext(n[i])[0]
|
||||||
|
|
@ -612,21 +611,22 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c
|
||||||
train_img = cv2.resize(train_img, (input_width, input_height),
|
train_img = cv2.resize(train_img, (input_width, input_height),
|
||||||
interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize
|
interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize
|
||||||
|
|
||||||
img[i - c] = train_img # add to array - img[0], img[1], and so on.
|
img[i - c, :] = train_img # add to array - img[0], img[1], and so on.
|
||||||
if task == "segmentation" or task=="binarization":
|
if task == "segmentation" or task=="binarization":
|
||||||
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
|
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')
|
||||||
train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width,
|
train_mask = resize_image(train_mask, input_height, input_width)
|
||||||
n_classes)
|
train_mask = get_one_hot(train_mask, input_height, input_width, n_classes)
|
||||||
elif task == "enhancement":
|
elif task == "enhancement":
|
||||||
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
|
train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255.
|
||||||
train_mask = resize_image(train_mask, input_height, input_width)
|
train_mask = resize_image(train_mask, input_height, input_width)
|
||||||
|
|
||||||
# train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]
|
# train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3]
|
||||||
|
|
||||||
mask[i - c] = train_mask
|
mask[i - c, :] = train_mask
|
||||||
except:
|
except Exception as e:
|
||||||
img[i - c] = np.ones((input_height, input_width, 3)).astype('float')
|
print(str(e))
|
||||||
mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float')
|
img[i - c, :] = 1.
|
||||||
|
mask[i - c, :] = 0.
|
||||||
|
|
||||||
c += batch_size
|
c += batch_size
|
||||||
if c + batch_size >= len(os.listdir(img_folder)):
|
if c + batch_size >= len(os.listdir(img_folder)):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue