diff --git a/src/eynollah/training/utils.py b/src/eynollah/training/utils.py index 56d6bdf..a03d539 100644 --- a/src/eynollah/training/utils.py +++ b/src/eynollah/training/utils.py @@ -600,10 +600,9 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c c = 0 n = [f for f in os.listdir(img_folder) if not f.startswith('.')] # os.listdir(img_folder) #List of training images random.shuffle(n) + img = np.zeros((batch_size, input_height, input_width, 3), dtype=float) + mask = np.zeros((batch_size, input_height, input_width, n_classes), dtype=float) while True: - img = np.zeros((batch_size, input_height, input_width, 3)).astype('float') - mask = np.zeros((batch_size, input_height, input_width, n_classes)).astype('float') - for i in range(c, c + batch_size): # initially from 0 to 16, c = 0. try: filename = os.path.splitext(n[i])[0] @@ -612,21 +611,22 @@ def data_gen(img_folder, mask_folder, batch_size, input_height, input_width, n_c train_img = cv2.resize(train_img, (input_width, input_height), interpolation=cv2.INTER_NEAREST) # Read an image from folder and resize - img[i - c] = train_img # add to array - img[0], img[1], and so on. + img[i - c, :] = train_img # add to array - img[0], img[1], and so on. if task == "segmentation" or task=="binarization": train_mask = cv2.imread(mask_folder + '/' + filename + '.png') - train_mask = get_one_hot(resize_image(train_mask, input_height, input_width), input_height, input_width, - n_classes) + train_mask = resize_image(train_mask, input_height, input_width) + train_mask = get_one_hot(train_mask, input_height, input_width, n_classes) elif task == "enhancement": train_mask = cv2.imread(mask_folder + '/' + filename + '.png')/255. train_mask = resize_image(train_mask, input_height, input_width) # train_mask = train_mask.reshape(224, 224, 1) # Add extra dimension for parity with train_img size [512 * 512 * 3] - mask[i - c] = train_mask - except: - img[i - c] = np.ones((input_height, input_width, 3)).astype('float') - mask[i - c] = np.zeros((input_height, input_width, n_classes)).astype('float') + mask[i - c, :] = train_mask + except Exception as e: + print(str(e)) + img[i - c, :] = 1. + mask[i - c, :] = 0. c += batch_size if c + batch_size >= len(os.listdir(img_folder)):