training: fix data representation in 7888fa5

(Eynollah models expect BGR/float instead of RGB/int)
This commit is contained in:
Robert Sachunsky 2026-02-24 16:46:19 +01:00
parent a9496bbc70
commit 56833b3f55

View file

@ -435,6 +435,19 @@ def run(_config,
sparse_y_true=False, sparse_y_true=False,
sparse_y_pred=False)]) sparse_y_pred=False)])
def _to_cv2float(img):
    """Convert an RGB uint8 image tensor to BGR float32 in [0, 1], as expected by Eynollah models."""
    bgr = tf.reverse(img, [-1])          # flip the channel axis: RGB -> BGR
    return tf.cast(bgr, tf.float32) / 255
def _to_intrgb(img):
    """Undo _to_cv2float for plotting: BGR float in [0, 1] back to RGB uint8."""
    quantized = tf.cast(img * 255, tf.uint8)  # rescale to 0..255 and truncate to uint8
    return tf.reverse(quantized, [-1])        # flip the channel axis: BGR -> RGB
def _to_categorical(seg):
    # Convert a label image tensor into one-hot class maps for training.
    # Presumably seg arrives as float in [0, 1] (the label dataset is mapped
    # through _to_cv2float before this) — rescale back to integer indices.
    # NOTE(review): int8 tops out at 127; assumes class indices stay small
    # relative to that — confirm against n_classes and the label encoding.
    seg = tf.cast(seg * 255, tf.int8)
    # gt_gen_utils/pagexml2label uses peculiar pseudo-RGB/index colors
    #seg = tf.image.rgb_to_grayscale(seg)
    # Keep only channel 0 as the class index.
    # NOTE(review): after the BGR flip in _to_cv2float, channel 0 is the
    # original blue channel — confirm the label images carry the index there
    # (or identically in every channel).
    seg = tf.gather(seg, [0], axis=-1)
    # Drop the singleton channel axis so one_hot appends the class axis.
    seg = tf.squeeze(seg, axis=-1)
    return one_hot(seg, n_classes)
def get_dataset(dir_imgs, dir_labs, shuffle=None): def get_dataset(dir_imgs, dir_labs, shuffle=None):
gen_kwargs = dict(labels=None, gen_kwargs = dict(labels=None,
label_mode=None, label_mode=None,
@ -452,25 +465,27 @@ def run(_config,
) )
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs) img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs) lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
img_gen = img_gen.map(_to_cv2float)
lab_gen = lab_gen.map(_to_cv2float)
if task in ["segmentation", "binarization"]: if task in ["segmentation", "binarization"]:
@tf.function lab_gen = lab_gen.map(_to_categorical)
def to_categorical(seg):
seg = tf.image.rgb_to_grayscale(seg)
seg = tf.cast(seg, tf.int8)
seg = tf.squeeze(seg, axis=-1)
return one_hot(seg, n_classes)
lab_gen = lab_gen.map(to_categorical)
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True) return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6)) train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False), callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
SaveWeightsAfterSteps(0, dir_output, _config)] SaveWeightsAfterSteps(0, dir_output, _config)]
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
_log.info("validating on %d batches", valdn_steps)
if save_interval: if save_interval:
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config)) callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
model.fit( model.fit(
train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()?? train_gen.prefetch(tf.data.AUTOTUNE),
validation_data=val_gen.prefetch(tf.data.AUTOTUNE), steps_per_epoch=train_steps,
validation_data=valdn_gen.prefetch(tf.data.AUTOTUNE),
validation_steps=valdn_steps,
verbose=1, verbose=1,
epochs=n_epochs, epochs=n_epochs,
callbacks=callbacks, callbacks=callbacks,