diff --git a/src/eynollah/training/train.py b/src/eynollah/training/train.py index f6117f7..4d0b317 100644 --- a/src/eynollah/training/train.py +++ b/src/eynollah/training/train.py @@ -435,6 +435,19 @@ def run(_config, sparse_y_true=False, sparse_y_pred=False)]) + def _to_cv2float(img): + # rgb→bgr and uint8→float, as expected by Eynollah models + return tf.cast(tf.reverse(img, [-1]), tf.float32) / 255 + def _to_intrgb(img): + # bgr→rgb and float→uint8 for plotting + return tf.reverse(tf.cast(img * 255, tf.uint8), [-1]) + def _to_categorical(seg): + seg = tf.cast(seg * 255, tf.int8) + # gt_gen_utils/pagexml2label uses peculiar pseudo-RGB/index colors + #seg = tf.image.rgb_to_grayscale(seg) + seg = tf.gather(seg, [0], axis=-1) + seg = tf.squeeze(seg, axis=-1) + return one_hot(seg, n_classes) def get_dataset(dir_imgs, dir_labs, shuffle=None): gen_kwargs = dict(labels=None, label_mode=None, @@ -452,25 +465,27 @@ def run(_config, ) img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs) lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs) + img_gen = img_gen.map(_to_cv2float) + lab_gen = lab_gen.map(_to_cv2float) if task in ["segmentation", "binarization"]: - @tf.function - def to_categorical(seg): - seg = tf.image.rgb_to_grayscale(seg) - seg = tf.cast(seg, tf.int8) - seg = tf.squeeze(seg, axis=-1) - return one_hot(seg, n_classes) - lab_gen = lab_gen.map(to_categorical) + lab_gen = lab_gen.map(_to_categorical) return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True) train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6)) - val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels) - callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False), SaveWeightsAfterSteps(0, dir_output, _config)] + valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels) + train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch + valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch + _log.info("training on %d batches in %d epochs", train_steps, n_epochs) + _log.info("validating on %d batches", valdn_steps) + if save_interval: callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) model.fit( - train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()?? - validation_data=val_gen.prefetch(tf.data.AUTOTUNE), + train_gen.prefetch(tf.data.AUTOTUNE), + steps_per_epoch=train_steps, + validation_data=valdn_gen.prefetch(tf.data.AUTOTUNE), + validation_steps=valdn_steps, verbose=1, epochs=n_epochs, callbacks=callbacks,