mirror of
https://github.com/qurator-spk/eynollah.git
synced 2026-03-02 05:11:57 +01:00
training: fix data representation in 7888fa5…
(Eynollah models expect BGR/float instead of RGB/int)
This commit is contained in:
parent
a9496bbc70
commit
56833b3f55
1 changed file with 26 additions and 11 deletions
|
|
@ -435,6 +435,19 @@ def run(_config,
|
||||||
sparse_y_true=False,
|
sparse_y_true=False,
|
||||||
sparse_y_pred=False)])
|
sparse_y_pred=False)])
|
||||||
|
|
||||||
|
def _to_cv2float(img):
    """Channel-swap RGB→BGR and rescale uint8 → float32 in [0, 1].

    Eynollah models expect OpenCV-style BGR float input rather than the
    RGB integer tensors that image_dataset_from_directory yields.
    """
    bgr = tf.reverse(img, [-1])  # flip the last (channel) axis: RGB -> BGR
    return tf.cast(bgr, tf.float32) / 255
|
def _to_intrgb(img):
    """Inverse of `_to_cv2float`: BGR float in [0, 1] back to RGB uint8 for plotting."""
    # Rescale/cast is elementwise, so doing it before the channel flip
    # is equivalent to the reverse order.
    as_uint8 = tf.cast(img * 255, tf.uint8)
    return tf.reverse(as_uint8, [-1])  # BGR -> RGB
|
def _to_categorical(seg):
    """Convert a float label image into a one-hot class map.

    `seg` has been through `_to_cv2float`, i.e. float32 in [0, 1]; scale
    it back to the original 0-255 integer label range first.
    """
    # Fix: cast to int32, not int8 — restored label values span 0-255,
    # and int8 wraps anything >= 128 to a negative index, which one_hot
    # then silently encodes as an all-zero vector.
    seg = tf.cast(seg * 255, tf.int32)
    # gt_gen_utils/pagexml2label uses peculiar pseudo-RGB/index colors;
    # only channel 0 is read here — presumably it carries the class
    # index (TODO confirm against the label generator).
    #seg = tf.image.rgb_to_grayscale(seg)
    seg = tf.gather(seg, [0], axis=-1)
    seg = tf.squeeze(seg, axis=-1)
    return one_hot(seg, n_classes)
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
def get_dataset(dir_imgs, dir_labs, shuffle=None):
|
||||||
gen_kwargs = dict(labels=None,
|
gen_kwargs = dict(labels=None,
|
||||||
label_mode=None,
|
label_mode=None,
|
||||||
|
|
@ -452,25 +465,27 @@ def run(_config,
|
||||||
)
|
)
|
||||||
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
img_gen = image_dataset_from_directory(dir_imgs, **gen_kwargs)
|
||||||
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
lab_gen = image_dataset_from_directory(dir_labs, **gen_kwargs)
|
||||||
|
img_gen = img_gen.map(_to_cv2float)
|
||||||
|
lab_gen = lab_gen.map(_to_cv2float)
|
||||||
if task in ["segmentation", "binarization"]:
|
if task in ["segmentation", "binarization"]:
|
||||||
@tf.function
|
lab_gen = lab_gen.map(_to_categorical)
|
||||||
def to_categorical(seg):
|
|
||||||
seg = tf.image.rgb_to_grayscale(seg)
|
|
||||||
seg = tf.cast(seg, tf.int8)
|
|
||||||
seg = tf.squeeze(seg, axis=-1)
|
|
||||||
return one_hot(seg, n_classes)
|
|
||||||
lab_gen = lab_gen.map(to_categorical)
|
|
||||||
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
return tf.data.Dataset.zip(img_gen, lab_gen).rebatch(n_batch, drop_remainder=True)
|
||||||
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
train_gen = get_dataset(dir_flow_train_imgs, dir_flow_train_labels, shuffle=np.random.randint(1e6))
|
||||||
val_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
|
||||||
|
|
||||||
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
callbacks = [TensorBoard(os.path.join(dir_output, 'logs'), write_graph=False),
|
||||||
SaveWeightsAfterSteps(0, dir_output, _config)]
|
SaveWeightsAfterSteps(0, dir_output, _config)]
|
||||||
|
valdn_gen = get_dataset(dir_flow_eval_imgs, dir_flow_eval_labels)
|
||||||
|
train_steps = len(os.listdir(dir_flow_train_imgs)) // n_batch
|
||||||
|
valdn_steps = len(os.listdir(dir_flow_eval_imgs)) // n_batch
|
||||||
|
_log.info("training on %d batches in %d epochs", train_steps, n_epochs)
|
||||||
|
_log.info("validating on %d batches", valdn_steps)
|
||||||
|
|
||||||
if save_interval:
|
if save_interval:
|
||||||
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
callbacks.append(SaveWeightsAfterSteps(save_interval, dir_output, _config))
|
||||||
model.fit(
|
model.fit(
|
||||||
train_gen.prefetch(tf.data.AUTOTUNE), # .repeat()??
|
train_gen.prefetch(tf.data.AUTOTUNE),
|
||||||
validation_data=val_gen.prefetch(tf.data.AUTOTUNE),
|
steps_per_epoch=train_steps,
|
||||||
|
validation_data=valdn_gen.prefetch(tf.data.AUTOTUNE),
|
||||||
|
validation_steps=valdn_steps,
|
||||||
verbose=1,
|
verbose=1,
|
||||||
epochs=n_epochs,
|
epochs=n_epochs,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue